diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml new file mode 100644 index 00000000000..8c529583c72 --- /dev/null +++ b/.github/workflows/_build.yml @@ -0,0 +1,225 @@ +name: ~Build wheel template + +on: + workflow_call: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "The C++11 ABI to use for the build" + required: true + type: string + upload-to-release: + description: "Upload wheel to this release" + required: false + type: boolean + default: false + release-version: + description: "Upload wheel to this release" + required: false + type: string + +defaults: + run: + shell: bash -x -e -u -o pipefail {0} + +jobs: + build-wheel: + runs-on: ${{ inputs.runs-on }} + name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}) + steps: + - name: Checkout + uses: actions/checkout@v5 + with: + ref: ${{ inputs.release-version }} + submodules: recursive + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + + - name: Set CUDA and PyTorch versions + run: | + echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV + echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV + echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + + - name: Free up disk space + if: ${{ runner.os == 'Linux' }} + # https://github.com/easimon/maximize-build-space/blob/master/action.yml + # https://github.com/easimon/maximize-build-space/tree/test-report + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + + - name: Set up swap space + if: runner.os == 'Linux' + uses: pierotofy/set-swap-space@v1.0 + with: + swap-size-gb: 10 + + - name: Install CUDA ${{ inputs.cuda-version }} + if: ${{ inputs.cuda-version != 'cpu' }} + uses: Jimver/cuda-toolkit@v0.2.29 + id: cuda-toolkit + with: + cuda: ${{ inputs.cuda-version }} + linux-local-args: '["--toolkit"]' + # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 + # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }} + method: "network" + sub-packages: '["nvcc"]' + + - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }} + run: | + pip install --upgrade pip + # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error + # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable + pip install typing-extensions==4.12.2 + # We want to figure out the CUDA version to download pytorch + # e.g. 
we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix + # This code is ugly, maybe there's a better way to do this. + export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ + minv = {'2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126, '2.9': 126}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129, '2.9': 130}[env['MATRIX_TORCH_VERSION']]; \ + print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ + ) + # detect if we're on ARM + if [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then + PLAT=linux_aarch64 + else + PLAT=manylinux_2_27_x86_64.manylinux_2_28_x86_64 + fi + echo "PLAT=$PLAT" >> $GITHUB_ENV + if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then + # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + # Can't use --no-deps because we need cudnn etc. + # Hard-coding this version of pytorch-triton for torch 2.9.0.dev20250904 + pip install jinja2 + TRITON_URL=https://download.pytorch.org/whl/nightly/pytorch_triton-3.4.0%2Bgitf7888497-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-${PLAT}.whl + TORCH_URL=https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_$(uname -m).whl + pip install --no-cache-dir --pre "${TRITON_URL}" + pip install --no-cache-dir --pre "${TORCH_URL}" + else + pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} + fi + nvcc --version + python --version + python -c "import torch; print('PyTorch:', torch.__version__)" + python -c "import torch; print('CUDA:', torch.version.cuda)" + python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" + + - name: Restore build cache + uses: actions/cache/restore@v4 + with: + path: build.tar + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + restore-keys: | + build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}- + + - name: Unpack build cache + run: | + echo ::group::Adjust timestamps + sudo find / -exec touch -t 197001010000 {} + || true + echo ::endgroup:: + + if [ -f build.tar ]; then + find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} + + tar -xpvf build.tar -C . 
+ else + echo "No build.tar found, skipping" + fi + + ls -al ./ + ls -al build/ || true + ls -al csrc/ || true + + - name: Build wheel + id: build_wheel + run: | + # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 + # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 + # However this still fails so I'm using a newer version of setuptools + pip install setuptools==75.8.0 + pip install ninja packaging wheel + export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH + export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH + # Limit MAX_JOBS otherwise the github runner goes OOM + # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM + + export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) + export NVCC_THREADS=2 + export FLASH_ATTENTION_FORCE_BUILD="TRUE" + export FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }} + + # 5h timeout since GH allows max 6h and we want some buffer + EXIT_CODE=0 + timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$? + + if [ $EXIT_CODE -eq 0 ]; then + tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }} + wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") + ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} + echo "wheel_name=${wheel_name}" >> $GITHUB_ENV + fi + + # Store exit code in GitHub env for later steps + echo "build_exit_code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT" + + # Do not fail the job if timeout killed the build + exit $EXIT_CODE + + - name: Log build logs after timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + run: | + ls -al ./ + tar -cvf build.tar . 
--atime-preserve=replace + + - name: Save build cache timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + uses: actions/cache/save@v4 + with: + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + path: build.tar + + - name: Log Built Wheels + run: | + ls dist + + - name: Get Release with tag + id: get_current_release + uses: joutvhu/get-release@v1 + with: + tag_name: ${{ inputs.release-version }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload Release Asset + id: upload_release_asset + if: inputs.upload-to-release + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.get_current_release.outputs.upload_url }} + asset_path: ./dist/${{env.wheel_name}} + asset_name: ${{env.wheel_name}} + asset_content_type: application/* diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 00000000000..25ea5e86b75 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,47 @@ +name: Build wheels + +on: + workflow_dispatch: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + default: ubuntu-22.04 + python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "Enable torch flag C++11 ABI (TRUE/FALSE)" + required: true + type: string + upload-to-release: + description: "Upload wheel to this release" + required: false + type: boolean + default: false + release-version: + description: "Upload wheel to this release" + required: false + type: string + +jobs: + build-wheels: + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ inputs.runs-on }} + python-version: ${{ inputs.python-version }} + cuda-version: ${{ inputs.cuda-version }} + torch-version: ${{ inputs.torch-version }} + cxx11_abi: ${{ inputs.cxx11_abi }} + upload-to-release: ${{ inputs.upload-to-release }} + release-version: ${{ inputs.release-version }} diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 00000000000..bc304a5641a --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,33 @@ +name: Lint + +on: + pull_request: + paths: + - 'flash_attn/cute/flash_bwd_sm90.py' + - 'flash_attn/cute/flash_bwd_preprocess.py' + - 'flash_attn/cute/flash_bwd_postprocess.py' + - 'flash_attn/cute/softmax.py' + - '.pre-commit-config.yaml' + push: + branches: + - main + paths: + - 'flash_attn/cute/flash_bwd_sm90.py' + - 'flash_attn/cute/flash_bwd_preprocess.py' + - 'flash_attn/cute/flash_bwd_postprocess.py' + - 'flash_attn/cute/softmax.py' + - '.pre-commit-config.yaml' + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.11' + + - name: Run pre-commit + uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 6f227d1abe1..d2bc31ed119 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -13,16 +13,16 @@ on: - v* jobs: - setup_release: name: Create Release runs-on: 
ubuntu-latest + outputs: + release-version: ${{ steps.extract_branch.outputs.branch }} steps: - name: Get the tag version id: extract_branch run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} shell: bash - - name: Create Release id: create_release uses: actions/create-release@v1 @@ -35,167 +35,50 @@ jobs: build_wheels: name: Build Wheel needs: setup_release - runs-on: ${{ matrix.os }} - strategy: fail-fast: false matrix: - # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the - # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. - os: [ubuntu-20.04] - python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] - cuda-version: ['12.4.1'] - # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. - # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. - # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) - # when building without C++11 ABI and using it on nvcr images. - cxx11_abi: ['FALSE', 'TRUE'] - exclude: - # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # Pytorch < 2.5 does not support Python 3.13 - - torch-version: '2.2.2' - python-version: '3.13' - - torch-version: '2.3.1' - python-version: '3.13' - - torch-version: '2.4.0' - python-version: '3.13' - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Set CUDA and PyTorch versions - run: | - echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV - echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV - echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. 
{'print $1 $2'})" >> $GITHUB_ENV - - - name: Free up disk space - if: ${{ runner.os == 'Linux' }} - # https://github.com/easimon/maximize-build-space/blob/master/action.yml - # https://github.com/easimon/maximize-build-space/tree/test-report - run: | - sudo rm -rf /usr/share/dotnet - sudo rm -rf /opt/ghc - sudo rm -rf /opt/hostedtoolcache/CodeQL - - - name: Set up swap space - if: runner.os == 'Linux' - uses: pierotofy/set-swap-space@v1.0 - with: - swap-size-gb: 10 - - - name: Install CUDA ${{ matrix.cuda-version }} - if: ${{ matrix.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.19 - id: cuda-toolkit - with: - cuda: ${{ matrix.cuda-version }} - linux-local-args: '["--toolkit"]' - # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 - # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }} - method: 'network' - sub-packages: '["nvcc"]' - - - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} - run: | - pip install --upgrade pip - # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools - pip install setuptools==75.8.0 - # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error - # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable - pip install typing-extensions==4.12.2 - # We want to figure out the CUDA version to download pytorch - # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the + # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. + os: [ubuntu-22.04, ubuntu-22.04-arm] + python-version: ["3.10", "3.11", "3.12", "3.13"] + torch-version: ["2.5.1", "2.6.0", "2.7.1", "2.8.0", "2.9.1"] + cuda-version: ["12.9.1"] + # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. + # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. + # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) + # when building without C++11 ABI and using it on nvcr images. + cxx11_abi: ["FALSE", "TRUE"] + include: + - torch-version: "2.9.1" + cuda-version: "13.0.2" + python-version: "3.14" + - torch-version: "2.10.0.dev20251108" + cuda-version: "13.0.2" + exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # This code is ugly, maybe there's a better way to do this. - export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \ - print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ - ) - if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then - # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} - # Can't use --no-deps because we need cudnn etc. 
- # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 - pip install jinja2 - pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl - pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl - else - pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} - fi - nvcc --version - python --version - python -c "import torch; print('PyTorch:', torch.__version__)" - python -c "import torch; print('CUDA:', torch.version.cuda)" - python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" - shell: - bash - - - name: Build wheel - run: | - # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 - # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 - # However this still fails so I'm using a newer version of setuptools - pip install setuptools==75.8.0 - pip install ninja packaging wheel - export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH - export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH - # Limit MAX_JOBS otherwise the github runner goes OOM - # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM - MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist - tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} - wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") - ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} - echo "wheel_name=${wheel_name}" >> $GITHUB_ENV - - - name: Log Built Wheels - run: | - ls dist - - - name: Get the tag version - id: extract_branch - run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} - - - name: Get Release with tag - id: get_current_release - uses: joutvhu/get-release@v1 - with: - tag_name: ${{ steps.extract_branch.outputs.branch }} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Upload Release Asset - id: upload_release_asset - uses: actions/upload-release-asset@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.get_current_release.outputs.upload_url }} - asset_path: ./dist/${{env.wheel_name}} - asset_name: ${{env.wheel_name}} - asset_content_type: application/* + # Pytorch < 2.5 does not support Python 3.13 + - torch-version: "2.4.0" + python-version: "3.13" + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ matrix.os }} + python-version: ${{ matrix.python-version }} + cuda-version: ${{ matrix.cuda-version }} + torch-version: ${{ matrix.torch-version }} + cxx11_abi: ${{ matrix.cxx11_abi }} + release-version: ${{ needs.setup_release.outputs.release-version }} + upload-to-release: true publish_package: name: Publish package needs: [build_wheels] - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v5 + - uses: actions/checkout@v5 + - uses: actions/setup-python@v6 with: - python-version: '3.10' - + python-version: "3.10" - name: Install dependencies run: | pip install 
ninja packaging wheel twine diff --git a/.gitignore b/.gitignore index 1f1f8028863..dc508654045 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.ncu-rep .DS_store +.vscode # Byte-compiled / optimized / DLL files __pycache__/ @@ -26,6 +27,10 @@ var/ # IDE-related .idea/ +.vscode/ # Dev venv + +# compile-time generated file +flash_attn_config.py \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 6216182e721..a6446cc597a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,4 @@ [submodule "csrc/composable_kernel"] path = csrc/composable_kernel url = https://github.com/ROCm/composable_kernel.git + branch = amd-master diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..6118dfa2283 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.13 + hooks: + - id: ruff-check + args: [--fix, --exit-non-zero-on-fix] + files: ^flash_attn/cute/.*\.py$ + exclude: &cute_exclude | + (?x)^flash_attn/cute/( + flash_bwd| + flash_fwd| + flash_fwd_sm100| + interface| + )\.py$ + - id: ruff-format + files: ^flash_attn/cute/.*\.py$ + exclude: *cute_exclude diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 043a972a210..96f46519721 --- a/README.md +++ b/README.md @@ -4,7 +4,541 @@ This is a fork of https://github.com/Dao-AILab/flash-attention customized for vL We have the following customizations: -- Build: Cmake, torch library (this package is bundled into vLLM). -- Size: reduced templating and removal of (training) kernels -- Features: Small page size support (FA2), DCP support (FA3) -- Performance: Some decode specific optimizations for sizes we care about; as well as mixed batch performance optimizations. (Upstream is understandably hesitant on specializing for inference as they also need to support training; we on the other hand compile out the backward pass kernels and do not test that our optimizations do not break them.) +Paper: https://tridao.me/publications/flash2/flash2.pdf + +![FlashAttention-2](assets/flashattention_logo.png) + + +## Usage + +We've been very happy to see FlashAttention being widely adopted in such a short +time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md) +contains a partial list of places where FlashAttention is being used. + +FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). +Please cite and credit FlashAttention if you use it. + + +## FlashAttention-3 beta release +FlashAttention-3 is optimized for Hopper GPUs (e.g. H100). + +Blogpost: https://tridao.me/blog/2024/flash3/ + +Paper: https://tridao.me/publications/flash3/flash3.pdf + +![FlashAttention-3 speedup on H100 80GB SXM5 with FP16](assets/flash3_fp16_fwd.png) + +This is a beta release for testing / benchmarking before we integrate that with +the rest of the repo. + +Currently released: +- FP16 / BF16 forward and backward, FP8 forward + +Requirements: H100 / H800 GPU, CUDA >= 12.3. + +We highly recommend CUDA 12.8 for best performance. + +To install: +```sh +cd hopper +python setup.py install +``` +To run the test: +```sh +export PYTHONPATH=$PWD +pytest -q -s test_flash_attn.py +``` +Once the package is installed, you can import it as follows: +```python +import flash_attn_interface +flash_attn_interface.flash_attn_func() +``` + +## Installation and features +**Requirements:** +- CUDA toolkit or ROCm toolkit +- PyTorch 2.2 and above. 
+- `packaging` Python package (`pip install packaging`) +- `ninja` Python package (`pip install ninja`) * +- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue. + +\* Make sure that `ninja` is installed and that it works correctly (e.g. `ninja +--version` then `echo $?` should return exit code 0). If not (sometimes `ninja +--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall +`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`, +compiling can take a very long time (2h) since it does not use multiple CPU +cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine using CUDA toolkit. + +**To install:** +```sh +pip install flash-attn --no-build-isolation +``` +Alternatively you can compile from source: +```sh +python setup.py install +``` + +If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might +run too many parallel compilation jobs that could exhaust the amount of RAM. To +limit the number of parallel compilation jobs, you can set the environment +variable `MAX_JOBS`: +```sh +MAX_JOBS=4 pip install flash-attn --no-build-isolation +``` + +**Interface:** `src/flash_attention_interface.py` + +### NVIDIA CUDA Support +**Requirements:** +- CUDA 12.0 and above. + +We recommend the +[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) +container from Nvidia, which has all the required tools to install FlashAttention. + +FlashAttention-2 with CUDA currently supports: +1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing + GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing + GPUs for now. +2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs). +3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5. + +### AMD ROCm Support +ROCm version has two backends. There is [composable_kernel](https://github.com/ROCm/composable_kernel) (ck) which is the default backend and a [Triton](https://github.com/triton-lang/triton) backend. They provide an implementation of FlashAttention-2. + +**Requirements:** +- ROCm 6.0 and above. + +We recommend the +[Pytorch](https://hub.docker.com/r/rocm/pytorch) +container from ROCm, which has all the required tools to install FlashAttention. + +#### Composable Kernel Backend +FlashAttention-2 ROCm CK backend currently supports: +1. MI200x, MI250x, MI300x, and MI355x GPUs. +2. Datatype fp16 and bf16 +3. Both forward's and backward's head dimensions up to 256. + +#### Triton Backend +The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress. + +It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes. 
+ +These features are supported in Fwd and Bwd +1) Fwd and Bwd with causal masking +2) Variable sequence lengths +3) Arbitrary Q and KV sequence lengths +4) Arbitrary head sizes +5) Multi and grouped query attention +6) Dropout +7) Rotary embeddings +8) ALiBi + +We are working on the following things +1) Paged Attention +2) Sliding Window +3) FP8 +4) Performance Improvements + +##### Getting Started +To get started with the triton backend for AMD, follow the steps below. + +First install the torch for ROCm from https://pytorch.org/get-started/locally/ if it is not installed. The torch and triton will be installed. + +Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. + +``` +cd flash-attention +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install +``` + +To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing. +``` +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py +``` + +You can use autotune for better performance by using this flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` +``` +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE +``` + +###### Docker +You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image. +``` +FROM rocm/pytorch:latest + +WORKDIR /workspace + +# install flash attention +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + +RUN git clone https://github.com/ROCm/flash-attention.git &&\ + cd flash-attention &&\ + python setup.py install + +# set working dir +WORKDIR /workspace/flash-attention +``` + +To build the docker file +``` +docker build -t fa_triton . +``` + +To run the docker image +``` +docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton +``` + +## How to use FlashAttention + +The main functions implement scaled dot product attention (softmax(Q @ K^T * +softmax_scale) @ V): +```python +from flash_attn import flash_attn_qkvpacked_func, flash_attn_func +``` + +```python +flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, + window_size=(-1, -1), alibi_slopes=None, deterministic=False): +"""dropout_p should be set to 0.0 during evaluation +If Q, K, V are already stacked into 1 tensor, this function will be faster than +calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation +of the gradients of Q, K, V. +If window_size != (-1, -1), implements sliding window local attention. Query at position i +will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. +Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to + the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. 
+Return: + out: (batch_size, seqlen, nheads, headdim). +""" +``` + +```python +flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, + window_size=(-1, -1), alibi_slopes=None, deterministic=False): +"""dropout_p should be set to 0.0 during evaluation +Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads +than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. +For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head +0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. +If window_size != (-1, -1), implements sliding window local attention. Query at position i +will only attend to keys between +[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + +Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. +Return: + out: (batch_size, seqlen, nheads, headdim). +""" +``` + +```python +def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + rotary_interleaved=True, + alibi_slopes=None, +): + """ + If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from + k and v. This is useful for incremental decoding: you can pass in the cached keys/values from + the previous step, and update them with the new keys/values from the current step, and do + attention with the updated cache, all in 1 kernel. + + If you pass in k / v, you must make sure that the cache is large enough to hold the new values. + For example, the KV cache could be pre-allocated with the max sequence length, and you can use + cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. + + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be + rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos + and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at + indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). + + See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. + + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. 
Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Note: Does not support backward pass. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) + page_block_size must be a multiple of 256. + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) + k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate + k with k_cache, starting at the indices specified by cache_seqlens. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. + cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the + KV cache. + block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + + Return: + out: (batch_size, seqlen, nheads, headdim). + """ +``` + +To see how these functions are used in a multi-head attention layer (which +includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py). 
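
For a quick end-to-end illustration, here is a minimal usage sketch of `flash_attn_func` on randomly initialized tensors. The batch size, sequence length, number of heads, and head dimension below are arbitrary example values chosen for the sketch (not values required by the library); a CUDA device and bf16 inputs are assumed, as described above.

```python
import torch
from flash_attn import flash_attn_func

# Arbitrary example sizes; see above for the supported head dimensions (up to 256).
batch_size, seqlen, nheads, headdim = 2, 1024, 16, 64
q = torch.randn(batch_size, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch_size, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch_size, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)

# Causal (autoregressive) attention; the output has the same shape as q.
out = flash_attn_func(q, k, v, causal=True)
assert out.shape == (batch_size, seqlen, nheads, headdim)
```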
+ +## Changelog + +### 2.0: Complete rewrite, 2x faster +Upgrading from FlashAttention (1.x) to FlashAttention-2 + +These functions have been renamed: +- `flash_attn_unpadded_func` -> `flash_attn_varlen_func` +- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func` +- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func` + +If the inputs have the same sequence lengths in the same batch, it is simpler +and faster to use these functions: +```python +flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False) +``` +```python +flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False) +``` +### 2.1: Change behavior of causal flag + +If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the +bottom right corner of the attention matrix, instead of the top-left corner. + +For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = +masked out) is: +v2.0: + 1 0 0 0 0 + 1 1 0 0 0 +v2.1: + 1 1 1 1 0 + 1 1 1 1 1 + +If seqlen_q = 5 and seqlen_k = 2, the causal mask is: +v2.0: + 1 0 + 1 1 + 1 1 + 1 1 + 1 1 +v2.1: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 +If the row of the mask is all zero, the output will be zero. + +### 2.2: Optimize for inference + +Optimize for inference (iterative decoding) when query has very small sequence +length (e.g., query sequence length = 1). The bottleneck here is to load KV +cache as fast as possible, and we split the loading across different thread +blocks, with a separate kernel to combine results. + +See the function `flash_attn_with_kvcache` with more features for inference +(perform rotary embedding, updating KV cache inplace). + +Thanks to the xformers team, and in particular Daniel Haziza, for this +collaboration. + +### 2.3: Local (i.e., sliding window) attention + +Implement sliding window attention (i.e., local attention). Thanks to [Mistral +AI](https://mistral.ai/) and in particular Timothée Lacroix for this +contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model. + +### 2.4: ALiBi (attention with linear bias), deterministic backward pass. + +Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution. + +Implement deterministic backward pass. Thanks to engineers from [Meituan](www.meituan.com) for this contribution. + +### 2.5: Paged KV cache. + +Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)). +Thanks to @beginlner for this contribution. + +### 2.6: Softcapping. + +Support attention with softcapping, as used in Gemma-2 and Grok models. +Thanks to @Narsil and @lucidrains for this contribution. + +### 2.7: Compatibility with torch compile + +Thanks to @ani300 for this contribution. + +## Performance + +We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory). + +We currently have benchmarks for these GPUs: +* [A100](#a100) +* [H100](#h100) + + + +### A100 + +We display FlashAttention speedup using these parameters: +* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads). +* Sequence length 512, 1k, 2k, 4k, 8k, 16k. +* Batch size set to 16k / seqlen. 
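
As a concrete sketch, the benchmark grid implied by these parameters can be enumerated as follows (the 1k-16k sequence lengths are expanded to exact powers of two; this is illustrative only, not the benchmark script itself):

```python
# Enumerate the (headdim, nheads, seqlen, batch_size) configurations described above.
hidden_dim = 2048
for headdim in (64, 128):
    nheads = hidden_dim // headdim       # 32 or 16 heads
    for seqlen in (512, 1024, 2048, 4096, 8192, 16384):
        batch_size = 16384 // seqlen     # batch size set to 16k / seqlen
        print(f"headdim={headdim}, nheads={nheads}, seqlen={seqlen}, batch={batch_size}")
```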
+ +#### Speedup + +![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png) + +#### Memory + +![FlashAttention memory](assets/flashattn_memory.jpg) + +We show memory savings in this graph (note that memory footprint is the same no matter if you use dropout or masking). +Memory savings are proportional to sequence length -- since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length. +We see 10X memory savings at sequence length 2K, and 20X at 4K. +As a result, FlashAttention can scale to much longer sequence lengths. + +### H100 + +![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png) + +## Full model code and training script + +We have released the full GPT model +[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py). +We also provide optimized implementations of other layers (e.g., MLP, LayerNorm, +cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x +compared to the baseline implementation from Huggingface, reaching up to 225 +TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need +any activation checkpointing). + +We also include a training +[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to +train GPT2 on Openwebtext and GPT3 on The Pile. + +## Triton implementation of FlashAttention + +Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton: +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +As Triton is a higher-level language than CUDA, it might be easier to understand +and experiment with. The notations in the Triton implementation are also closer +to what's used in our paper. + +We also have an experimental implementation in Triton that support attention +bias (e.g. ALiBi): +https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py + + +## Tests +We test that FlashAttention produces the same output and gradient as a reference +implementation, up to some numerical tolerance. In particular, we check that the +maximum numerical error of FlashAttention is at most twice the numerical error +of a baseline implementation in Pytorch (for different head dimensions, input +dtype, sequence length, causal / non-causal). + +To run the tests: +```sh +pytest -q -s tests/test_flash_attn.py +``` +## When you encounter issues + +This new release of FlashAttention-2 has been tested on several GPT-style +models, mostly on A100 GPUs. + +If you encounter bugs, please open a GitHub Issue! + +## Tests +To run the tests: +```sh +pytest tests/test_flash_attn_ck.py +``` + +## Citation +If you use this codebase, or otherwise found our work valuable, please cite: +``` +@inproceedings{dao2022flashattention, + title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, + author={Dao, Tri and Fu, Daniel Y. 
and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, + booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, + year={2022} +} +@inproceedings{dao2023flashattention2, + title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, + author={Dao, Tri}, + booktitle={International Conference on Learning Representations (ICLR)}, + year={2024} +} +``` diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py new file mode 100644 index 00000000000..6158eddc174 --- /dev/null +++ b/benchmarks/benchmark_attn.py @@ -0,0 +1,420 @@ +from collections import namedtuple +from functools import partial +import math +import os +from typing import NamedTuple +import torch +import torch.nn as nn +import torch.nn.functional as F + +import time + +try: + import cudnn +except ImportError: + cudnn = None +# cudnn = None + +Timing = NamedTuple('timing', [('mean', float)]) + + +from einops import rearrange, repeat + +# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.cute.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler + +try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func +except ImportError: + flash_attn_func = None + flash_attn_varlen_func = None +from flash_attn.cute.interface import flash_attn_func as flash_attn_func_python +from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python +try: + from flash_attn_interface import flash_attn_func as flash_attn_func_v3 + from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 +except ImportError: + flash_attn_func_v3 = None + flash_attn_varlen_func_v3 = None + +if torch.cuda.get_device_capability()[0] != 9: + flash_attn_func_v3 = None +# flash_attn_func_v3 = None + +flash_attn_func = None + +from triton.testing import do_bench + +def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): + # # Warmup + # for _ in range(5): + # func(*args, **kwargs) + # time.sleep(1) + # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1] + # s = torch.cuda.Stream() + # s.wait_stream(torch.cuda.current_stream()) + # with torch.cuda.stream(s): + # for _ in range(2): + # out = func(*args, **kwargs) + # torch.cuda.current_stream().wait_stream(s) + # graph = torch.cuda.CUDAGraph() + # with torch.cuda.graph(graph): + # out = func(*args, **kwargs) + # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc) + # # return time_f[1].mean + # return time_f[1] + return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) + + +def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None)): + if causal: + avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 + else: + if window_size == (None, None): + avg_seqlen = seqlen_k + else: + row_idx = torch.arange(seqlen_q, device='cuda') + col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) if window_size[0] is not None else torch.zeros_like(row_idx) + col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)) if window_size[1] is not None else torch.full_like(row_idx, seqlen_k - 1) + avg_seqlen = (col_right - col_left + 1).float().mean().item() + 
return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) + + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == torch.int64: + return cudnn.data_type.INT64 + else: + raise ValueError("Unsupported tensor data type.") + + +def cudnn_spda_setup(q, k, v, causal=False, window_size_left=None): + b, nheads, seqlen_q, headdim = q.shape + _, nheads_k, seqlen_k, _ = k.shape + headdim_v = v.shape[-1] + assert v.shape == (b, nheads_k, seqlen_k, headdim_v) + assert cudnn is not None, 'CUDNN is not available' + q_gpu, k_gpu, v_gpu = q, k, v + o_gpu = torch.empty((b, nheads, seqlen_q, headdim_v), dtype=q.dtype, device=q.device) + stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device) + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(q.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q = graph.tensor_like(q_gpu.detach()) + k = graph.tensor_like(k_gpu.detach()) + v = graph.tensor_like(v_gpu.detach()) + + o, stats = graph.sdpa( + name="sdpa", + q=q, + k=k, + v=v, + is_inference=False, + attn_scale=1.0 / math.sqrt(headdim), + # use_causal_mask_bottom_right=causal or window_size_left is not None, + use_causal_mask=causal or window_size_left is not None, + sliding_window_length=window_size_left if window_size_left is not None and not causal else None, + ) + + o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride()) + stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: q_gpu, + k: k_gpu, + v: v_gpu, + o: o_gpu, + stats: stats_gpu, + } + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(*args, **kwargs): + graph.execute(variant_pack, workspace) + return o_gpu + + return run + + +def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None): + b, nheads, seqlen_q, headdim = q.shape + _, nheads_k, seqlen_k, _ = k.shape + headdim_v = v.shape[-1] + assert v.shape == (b, nheads_k, seqlen_k, headdim_v) + assert g.shape == (b, nheads, seqlen_q, headdim_v) + assert o.shape == (b, nheads, seqlen_q, headdim_v) + assert lse.shape == (b, nheads, seqlen_q, 1) + assert cudnn is not None, 'CUDNN is not available' + q_gpu, k_gpu, v_gpu, o_gpu, g_gpu = q, k, v, o, g + dq_gpu = torch.empty_like(q_gpu) + dk_gpu = torch.empty_like(k_gpu) + dv_gpu = torch.empty_like(v_gpu) + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(q.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q = graph.tensor_like(q_gpu.detach()) + k = graph.tensor_like(k_gpu.detach()) + v = graph.tensor_like(v_gpu.detach()) + o = graph.tensor_like(o_gpu.detach()) + g = graph.tensor_like(g_gpu.detach()) + stats = graph.tensor_like(lse.detach()) + + dq, dk, dv = graph.sdpa_backward( + name="sdpa_backward", + q=q, + k=k, + v=v, + o=o, + dO=g, + stats=stats, + attn_scale=1.0 / math.sqrt(headdim), + # use_causal_mask_bottom_right=causal or window_size_left is not None, + use_causal_mask=causal or window_size_left is not None, + 
sliding_window_length=window_size_left if window_size_left is not None and not causal else None, + use_deterministic_algorithm=False, + ) + + dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride()) + dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride()) + dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride()) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: q_gpu, + k: k_gpu, + v: v_gpu, + o: o_gpu, + g: g_gpu, + stats: lse, + dq: dq_gpu, + dk: dk_gpu, + dv: dv_gpu, + } + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(*args, **kwargs): + graph.execute(variant_pack, workspace) + return dq_gpu, dk_gpu, dv_gpu + + return run + + +torch.manual_seed(0) +repeats = 10 +dropout_p = 0.0 +causal = False +dtype = torch.bfloat16 +# dtype = torch.float8_e4m3fn +dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype +device = 'cuda' +verbose = True +varlen = False +has_backward = True +page_size = None +# page_size = 128 +softcap = 0.0 +V_colmajor = False +deterministic = False +batch_size = 2 +# seqlen = 2048 +seqlen = 8192 +# seqlen = 4096 +# seqlen = 2047 +dim = 2048 +# headdim = 128 +# headdim = 64 +headdim = 256 +# for headdim in [64, 128, 256]: +# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] +# bs_seqlen_vals = [(32, 1024), (16, 2048), (8, 4096), (4, 8192), (2, 16384), (1, 32768)] +# bs_seqlen_vals = [(32, 512), (16, 1024)] +# bs_seqlen_vals = [(2, 64 * 132)] +bs_seqlen_vals = [(4, 8192)] +# bs_seqlen_vals = [(1, 16 * 1024)] +time_f = {} +time_b = {} + +# for headdim in [64, 128, 256]: +# for headdim in [64, 96, 128, 192]: +# for headdim in [64, 96, 128, 192, 256]: +# for headdim in [64, 96, 128]: +# for headdim in [64, 128, 256]: +# for headdim in [64, 96, 128, 192, 256]: +for headdim in [128]: + # nheads = dim // headdim + nheads = 32 if headdim <= 64 else 16 if headdim <= 192 else 8 + # nheads = 128 + # headdim = 64 + # batch_size = 64 + # seqlen = 512 + # nheads = 8 + # headdim = 128 + nheads_kv = nheads + # nheads_kv = nheads // 8 + # nheads_kv = 1 + # headdim_v = headdim + headdim_v = 128 if headdim == 192 else headdim + # headdim_v = 512 + has_qv = headdim == 64 and headdim_v == 512 + # has_qv = False + # sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device) + sinks = None + + for batch_size, seqlen in bs_seqlen_vals: + num_splits = 0 + # window_size = (-1, -1) + window_size = (None, None) + window_size_fa = (-1, -1) + # window_size = (seqlen // 2 - 1, 0) + pack_gqa = None + # seqlen_q = 64 + seqlen_q = seqlen + leftpad_k = None + # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) + k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) + v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=has_backward) + q, k, v = [x.detach().to(dtype).requires_grad_(has_backward) for x in [q, k, v]] + v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_(has_backward) + v_fa3 = v if not V_colmajor else v_colmajor + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if 
has_qv else None + # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) + # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) + # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype) + g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) + o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) + stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32) + if varlen: + q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) for x in [q, k, v]] + cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q + cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen if page_size is None else None + # cu_seqlens_q = torch.tensor([0, 248, 249, 250, 251, 252, 253, 254, 255, 256], device=device, dtype=torch.int32) + # q_unpad = q_unpad[:256] + # seqlen_q = 256 + # cu_seqlens_q = torch.tensor([0, 376, 377, 378, 379, 380, 381, 382, 383, 384], device=device, dtype=torch.int32) + # q_unpad = q_unpad[:384] + # seqlen_q = 384 + if page_size is not None: + assert seqlen % page_size == 0 + k_paged, v_paged = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k, v]] + page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), + "(b s) -> b s", s=seqlen // page_size) + else: + page_table = None + + for causal in [False, True]: + # for causal in [True]: + print(f"\n### {headdim = }, {causal = }, {seqlen = }, {batch_size = }, {nheads = }, {nheads_kv = }, {varlen = }, {deterministic = } ###") + nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) + if cudnn is not None: + # if False: + if headdim <= 256 and dtype != torch.float8_e4m3fn: + cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0]) + if has_backward and headdim == headdim_v: + cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0]) + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None: + # if False: + if not varlen: + m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') + else: + m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') + time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = m0.mean + if has_backward: + time.sleep(1) + if not varlen: + _, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + repeats=repeats, verbose=False, desc='Fav2') + else: + _, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + repeats=repeats, 
verbose=False, desc='Fav2') + time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean + # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True) + + if cudnn is not None: + # if False: + if headdim <= 256 and dtype != torch.float8_e4m3fn: + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark + m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN') + time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean + if has_backward: + time.sleep(1) + m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN') + time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean + # pytorch_profiler(cudnn_spda, backward=False) + # pytorch_profiler(cudnn_spda_bwd, backward=False) + time.sleep(1) + if flash_attn_func_v3 is not None: + if not varlen: + # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) + else: + m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) + time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean + if flash_attn_func_python is not None: + if not varlen: + m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python') + else: + m1_py = time_fwd(flash_attn_varlen_func_python, q_unpad, k_unpad if page_size is None else k_paged, v_unpad if page_size is None else v_paged, cu_seqlens_q, cu_seqlens_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_v3 is not None and has_backward: + time.sleep(1) + if not varlen: + _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav3') + else: + _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + repeats=repeats, verbose=False, desc='Fav3') + time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean + time.sleep(1) + # if not 
varlen: + # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True) + # else: + # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True) + # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward: + if not varlen: + _, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python') + else: + _, m1b_py = benchmark_backward(flash_attn_varlen_func_python, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python') + + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None: + # if False: + print(f'FAv2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS') + if has_backward: + print(f'FAv2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS') + if cudnn is not None: + print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') + if has_backward: + print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') + if flash_attn_func_v3 is not None: + print(f'FAv3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward: + print(f'FAv3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS') + + if flash_attn_func_python is not None: + print(f'FA Python fwd: {m1_py.mean * 1e3:.3f}ms, {(nFLOPS / m1_py.mean * 1e-12):.1f} TFLOPS') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward: + print(f'FA Python bwd: {m1b_py.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b_py.mean * 1e-12):.1f} TFLOPS') diff --git a/benchmarks/benchmark_causal.py b/benchmarks/benchmark_causal.py index 6c4797c83e0..297055df78d 100644 --- a/benchmarks/benchmark_causal.py +++ b/benchmarks/benchmark_causal.py @@ -17,12 +17,6 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func -try: - from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax -except ImportError: - scaled_upper_triang_masked_softmax = None - - def attention_pytorch(qkv, dropout_p=0.0, causal=True): """ Arguments: @@ -52,27 +46,6 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True): return output.to(dtype=qkv.dtype) -def attention_megatron(qkv): - """ - Arguments: - qkv: (batch_size, seqlen, 3, nheads, head_dim) - Output: - output: (batch_size, seqlen, nheads, head_dim) - """ - batch_size, seqlen, _, nheads, d = qkv.shape - q, k, v = qkv.unbind(dim=2) - q = rearrange(q, 'b t h d -> (b h) t d') - k = rearrange(k, 'b s h d -> (b h) d s') - softmax_scale = 1.0 / math.sqrt(d) - # Preallocate attn_weights for `baddbmm` - scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) - scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), - '(b h) t s -> b h t s', h=nheads) - attention = scaled_upper_triang_masked_softmax(scores, None, scale=1.0) - output = torch.einsum('bhts,bshd->bthd', attention, v) - return output.to(dtype=qkv.dtype) - - torch.manual_seed(0) repeats = 30 batch_size = 8 diff --git 
a/csrc/composable_kernel b/csrc/composable_kernel index 888317e698e..13f6d635653 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit 888317e698e9803c62bd38568abc9e05d7709f33 +Subproject commit 13f6d635653bd5ffbfcac8577f1ef09590c23d78 diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index aa66087f6d4..d8f663284f3 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -431,7 +431,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -516,7 +516,7 @@ std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. @@ -649,7 +649,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); @@ -864,7 +864,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -1081,7 +1081,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64); const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); @@ -1360,7 +1360,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, 8); - const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu deleted file mode 100644 index e34dd2454ba..00000000000 --- a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_bwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { - run_mha_bwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu deleted file mode 100644 index 5089d988d99..00000000000 --- a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_bwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { - run_mha_bwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu deleted file mode 100644 index 0272c579755..00000000000 --- a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_bwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { - run_mha_bwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu deleted file mode 100644 index d3d5d98d12d..00000000000 --- a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_bwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { - run_mha_bwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index b719cf98870..72e7a333b3a 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -102,7 +102,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; if (smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( @@ -261,26 +261,6 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) { }); } -template -void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) { - constexpr static int Headdim = 160; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if (max_smem_per_block >= 116 * 1024) { - run_flash_bwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_bwd, Is_dropout, Is_causal>(params, stream); - } - }); -} - template void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) { constexpr static int Headdim = 192; diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu deleted file mode 100644 index 27d9e9d8a7b..00000000000 --- a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu deleted file mode 100644 index 943e508eb16..00000000000 --- a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu deleted file mode 100644 index 92904627b9f..00000000000 --- a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu deleted file mode 100644 index 7b3749e2551..00000000000 --- a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 227f3c25729..83ab14581a1 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -76,7 +76,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { // If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // auto kernel = &flash_fwd_kernel; @@ -117,7 +117,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) { // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
// If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; + auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; if (smem_size >= 48 * 1024) { @@ -257,34 +257,6 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) { }); } -template -void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) { - constexpr static int Headdim = 160; - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); - bool is_sm8x = cc_major == 8 && cc_minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // For A100, H100, 128 x 32 is the fastest. - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 64 with 8 warps is the fastest for non-causal. - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); -} - template void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) { constexpr static int Headdim = 192; diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu deleted file mode 100644 index f5167b33392..00000000000 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu +++ /dev/null @@ -1,11 -0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream); - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu deleted file mode 100644 index ee02db1a341..00000000000 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu +++ /dev/null @@ -1,11 -0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream); - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu deleted file mode 100644 index 2b0472038f7..00000000000 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu +++ /dev/null @@ -1,11 -0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream); - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu deleted file mode 100644 index 2b833bd537b..00000000000 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu +++ /dev/null @@ -1,11 -0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream); - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/generate_kernels.py b/csrc/flash_attn/src/generate_kernels.py index 755ee8fea58..4d8e7ffdfd0 100644 --- a/csrc/flash_attn/src/generate_kernels.py +++ b/csrc/flash_attn/src/generate_kernels.py @@ -10,7 +10,7 @@ } SM = [80] # Sm80 kernels support up to -HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 256] +HEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256] IS_CAUSAL = ["false", "true"] NAMESPACE_INCLUDE = '#include "namespace_config.h"\n' diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index a57702f6ce7..70d14daf69d 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -101,9 +101,6 @@ } else if (HEADDIM <= 128) { \ constexpr static int kHeadDim = 128; \ return __VA_ARGS__(); \ - } else if (HEADDIM <= 160) { \ - constexpr static int kHeadDim = 160; \ - return __VA_ARGS__(); \ } else if (HEADDIM <= 192) { \ constexpr static int kHeadDim = 192; \ return __VA_ARGS__(); \ diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp index 1f016a4a4e6..083494f5b0c 100644 --- a/csrc/flash_attn_ck/mha_bwd.cpp +++ b/csrc/flash_attn_ck/mha_bwd.cpp @@ -133,9 +133,12 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, dv.data_ptr(), nullptr, // dbias dq_acc.data_ptr(), // dq_acc - nullptr, // seqstart_q - nullptr, // seqstart_k + nullptr, // seqstart_q_ptr + nullptr, // seqstart_k_ptr + nullptr, // seqlen_q_ptr nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr seqlen_q, seqlen_k, b, @@ -220,7 +223,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num if (is_causal) { window_size_right = 0; } bool is_dropout = p_dropout > 0.0; +#ifdef HIPIFY_V2 + auto stream = at::cuda::getCurrentCUDAStream().stream(); +#else auto stream = at::cuda::getCurrentHIPStream().stream(); +#endif auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, @@ -399,4 +406,4 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num } return { dq, dk, dv, softmax_d }; -} \ No newline at end of file +} diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp index a3867682168..0229e777cd5 100644 --- a/csrc/flash_attn_ck/mha_fwd.cpp +++ b/csrc/flash_attn_ck/mha_fwd.cpp @@ -19,11 +19,12 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask, dtype, false, // is_group_mode true, // is_v_rowmajor + false, // has_logits_soft_cap mask.type, enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, has_lse, has_dropout, - false}; // do_fp8_static_quant + quant_scale_enum::no_scale}; // qscale_type } fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, @@ -94,12 +95,18 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, k.data_ptr(), v.data_ptr(), alibi_slopes_ptr, // bias + nullptr, // q_descale_ptr + nullptr, // k_descale_ptr + nullptr, // v_descale_ptr has_dropout_randval ? dropout_randval.data_ptr() : nullptr, has_lse ? softmax_lse.data_ptr() : nullptr, out.data_ptr(), - nullptr, // seqstart_q - nullptr, // seqstart_k - nullptr, + nullptr, // seqstart_q_ptr + nullptr, // seqstart_k_ptr + nullptr, // seqlen_q_ptr + nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr seqlen_q, seqlen_k, b, @@ -109,8 +116,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, h, // nhead h_k, // nhead_k softmax_scale, // scale_s - 1, // scale_p - 1, // scale_o + 0.0f, // logits_soft_cap stride_q, stride_k, stride_v, @@ -134,6 +140,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, mask.left, mask.right, static_cast(mask.type), + 0, // min_seqlen_q p_dropout, has_dropout_randval, drop_seed_offset}; @@ -269,7 +276,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num if (seqlen_k > 0) { auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); +#ifdef HIPIFY_V2 + auto stream = at::cuda::getCurrentCUDAStream().stream(); +#else auto stream = at::cuda::getCurrentHIPStream().stream(); +#endif ck_tile::stream_config stream_config{stream}; auto traits = diff --git a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp index bcb8e3bbb96..27866f1902e 100644 --- a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp +++ b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp @@ -33,7 +33,8 @@ fmha_fwd_splitkv_traits get_ck_fmha_fwd_splitkv_traits(const mask_info &mask, head_size, dtype, false, // is_group_mode - true, // is_v_rowmajor + true, // is_v_rowmajor + false, // has_logits_soft_cap mask.type, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, diff --git a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp index bfeb3b770d0..3cd01c32d48 100644 --- a/csrc/flash_attn_ck/mha_varlen_bwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp @@ -139,9 +139,12 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, dv.data_ptr(), nullptr, // dbias dq_acc.data_ptr(), // dq_acc - seqlens_q.data_ptr(), // seqstart_q - seqlens_k.data_ptr(), // seqstart_k + seqlens_q.data_ptr(), // seqstart_q_ptr + seqlens_k.data_ptr(), // seqstart_k_ptr + nullptr, // seqlen_q_ptr nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr total_q, total_k, b, diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index 6274750f588..00b0fcd5738 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -17,13 +17,14 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, return fmha_fwd_traits{head_size, head_size, dtype, - true, // is_group_mode - true, // is_v_rowmajor + true, // is_group_mode + true, // is_v_rowmajor + false, // has_logits_soft_cap mask.type, enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, has_lse, has_dropout, - false}; // do_fp8_static_quant + quant_scale_enum::no_scale}; // qscale_type } fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &mask, @@ -35,8 +36,9 @@ fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &m return fmha_fwd_splitkv_traits{head_size, head_size, dtype, - true, // is_group_mode - true, // is_v_rowmajor + true, // is_group_mode + true, // is_v_rowmajor + false, // has_logits_soft_cap mask.type, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, @@ -114,12 +116,18 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, k.data_ptr(), v.data_ptr(), alibi_slopes_ptr, // bias + nullptr, // q_descale_ptr + nullptr, // k_descale_ptr + nullptr, // v_descale_ptr has_dropout_randval ? dropout_randval.data_ptr() : nullptr, has_lse ? softmax_lse.data_ptr() : nullptr, out.data_ptr(), - seqlens_q.data_ptr(), // seqstart_q - seqlens_k.data_ptr(), // seqstart_k - nullptr, // seqlen_kpads + seqlens_q.data_ptr(), // seqstart_q_ptr + seqlens_k.data_ptr(), // seqstart_k_ptr + nullptr, // seqlen_q_ptr + nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_kv_ptr total_q, total_k, b, @@ -129,8 +137,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, h, // nhead h_k, // nhead_k softmax_scale, // scale_s - 1, // scale_p - 1, // scale_o + 0.0f, // logits_soft_cap stride_q, stride_k, stride_v, @@ -154,6 +161,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, mask.left, mask.right, static_cast(mask.type), + 0, // min_seqlen_q p_dropout, has_dropout_randval, drop_seed_offset}; @@ -465,7 +473,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si } if (max_seqlen_k > 0) { +#ifdef HIPIFY_V2 + auto stream = at::cuda::getCurrentCUDAStream().stream(); +#else auto stream = at::cuda::getCurrentHIPStream().stream(); +#endif ck_tile::stream_config stream_config{stream}; if (paged_KV) diff --git a/csrc/ft_attention/README.md b/csrc/ft_attention/README.md deleted file mode 100644 index 97feb78cc1c..00000000000 --- a/csrc/ft_attention/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# Attention kernel from FasterTransformer - -This CUDA extension wraps the single-query attention [kernel](https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp) from -FasterTransformer v5.2.1 for benchmarking purpose. - -```sh -cd csrc/ft_attention && pip install . -``` - -As of 2023-09-17, this extension is no longer used in the FlashAttention repo. -FlashAttention now has implemented -[`flash_attn_with_kvcache`](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attention_interface.py) -with all the features of this `ft_attention` kernel (and more). - diff --git a/csrc/ft_attention/cuda_bf16_fallbacks.cuh b/csrc/ft_attention/cuda_bf16_fallbacks.cuh deleted file mode 100644 index f5641f61609..00000000000 --- a/csrc/ft_attention/cuda_bf16_fallbacks.cuh +++ /dev/null @@ -1,257 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include - -namespace fastertransformer { - -#ifdef ENABLE_BF16 -inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = __low2float(val); - f_val.y = __high2float(val); - return f_val; -#else - return __bfloat1622float2(val); -#endif -} - -inline __device__ int16_t bf1622int16(__nv_bfloat162 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = max(min(__low2float(val), 127.f), -128.f); - f_val.y = max(min(__high2float(val), 127.f), -128.f); - union { int8_t int8[2]; int16_t int16; }; - int8[0] = static_cast(static_cast(f_val.x)); - int8[1] = static_cast(static_cast(f_val.y)); - return int16; -#else - val = __hmin2(val, make_bfloat162(127., 127.)); - val = __hmax2(val, make_bfloat162(-128., -128.)); - union { int8_t int8[2]; int16_t int16; }; - int8[0] = static_cast(static_cast(val.x)); - int8[1] = static_cast(static_cast(val.y)); - return int16; -#endif -} - -inline __device__ __nv_bfloat162 float22bf162(const float2 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __floats2bfloat162_rn(val.x, val.y); -#else - return __float22bfloat162_rn(val); -#endif -} - -inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - __nv_bfloat162 val2; - val2.x = val; - val2.y = val; - return val2; -#else - return __bfloat162bfloat162(val); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); -#else - return __hadd2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) ); -#else - return __hadd(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); -#else - return __hsub2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) ); -#else - return __hsub(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); -#else - return __hmul2(x, y); 
-#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); -#else - return __hmul(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh, fzl, fzh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - fzl = __low2float(z); - fzh = __high2float(z); - return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); -#else - return __hfma2(x, y, z); -#endif -} - -inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); -#else - return __hfma(x, y, z); -#endif -} - -inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh; - fxl = __low2float(x); - fxh = __high2float(x);; - return __floats2bfloat162_rn(expf(fxl), expf(fxh)); -#else - return h2exp(x); -#endif -} - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }; -inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }; - -inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ - __nv_bfloat162 t; t.x = x; t.y = y; return t; -} - -#endif - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d)); -#else - return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); -#else - 
return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - fdl = __low2float(d); - fdh = __high2float(d); - return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); -#else - return a * b * c + d; -#endif -} - -#endif // ENABLE_BF16 - -} // namespace fastertransformer diff --git a/csrc/ft_attention/cuda_bf16_wrapper.h b/csrc/ft_attention/cuda_bf16_wrapper.h deleted file mode 100644 index efb6e798730..00000000000 --- a/csrc/ft_attention/cuda_bf16_wrapper.h +++ /dev/null @@ -1,23 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#ifdef ENABLE_BF16 -#include -#endif diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.cu b/csrc/ft_attention/decoder_masked_multihead_attention.cu deleted file mode 100644 index 13306f76868..00000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention.cu +++ /dev/null @@ -1,149 +0,0 @@ -// Adapted from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "decoder_masked_multihead_attention.h" -#include "decoder_masked_multihead_attention_utils.h" -#include "cuda_bf16_wrapper.h" -#include -#include -#include - -#include "decoder_masked_multihead_attention_template.hpp" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - auto kernel = mmha::masked_multihead_attention_kernel; \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ - dim3 grid(params.nnz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size); \ - kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); - if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); - } - else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); - } - else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#undef MMHA_LAUNCH_KERNEL - -template -void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - switch (params.hidden_size_per_head) { - case 32: - mmha_launch_kernel(params, stream); - break; - case 48: - mmha_launch_kernel(params, stream); - break; - case 64: - mmha_launch_kernel(params, stream); - break; - case 80: - mmha_launch_kernel(params, stream); - break; - case 96: - mmha_launch_kernel(params, stream); - break; - case 128: - mmha_launch_kernel(params, stream); - break; - case 160: - mmha_launch_kernel(params, stream); - break; - case 192: - mmha_launch_kernel(params, stream); - break; - case 224: - mmha_launch_kernel(params, stream); - break; - case 256: - mmha_launch_kernel(params, stream); - break; - default: - assert(false); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.h b/csrc/ft_attention/decoder_masked_multihead_attention.h deleted file mode 100644 index 3c79f88b856..00000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention.h +++ /dev/null @@ -1,192 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while (0) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// The structure of parameters for the masked multihead attention kernel. -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. - -template -struct Multihead_attention_params_base { - - // The output buffer. Dimensions B x D. - T* out = nullptr; - - // The input Qs and the associated bias. Dimensions B x D and D, resp. - const T *q = nullptr, *q_bias = nullptr; - // The input Ks and the associated bias. Dimensions B x D and D, resp. - const T *k = nullptr, *k_bias = nullptr; - // The input Vs and the associated bias. Dimensions B x D and D, resp. - const T *v = nullptr, *v_bias = nullptr; - - // The cache for the Ks. The size must be at least B x L x D. - T* k_cache = nullptr; - // The cache for the Vs. 
The size must be at least B x L x D. - T* v_cache = nullptr; - // The indirections to use for cache when beam sampling. - const int* cache_indir = nullptr; - - // Stride to handle the case when KQV is a single buffer - int stride_q = 0; - int stride_k = 0; - int stride_v = 0; - - // The batch size. - int batch_size = 0; - // The beam width - int beam_width = 0; - // The sequence length. - int memory_max_len = 0; - // The number of heads (H). - int num_heads = 0; - int num_heads_kv = 0; - int num_heads_q_kv_ratio = 0; - // The hidden dimension per head (Dh). - int hidden_size_per_head = 0; - // The per-head latent space reserved for rotary embeddings. - int rotary_embedding_dim = 0; - bool neox_rotary_style = false; - float rotary_base = 0.0f; - // The maximum length of input sentences. - int max_input_length = 0; - // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? - int timestep = 0; - // The current timestep of each sentences (support different timestep for different sentences) - - // The 1.f / sqrt(Dh). Computed on the host. - float inv_sqrt_dh = 0.0f; - - // Used when we have some input context like gpt - const int* total_padding_tokens = nullptr; - - const bool* masked_tokens = nullptr; - const int* prefix_prompt_lengths = nullptr; - int max_prefix_prompt_length = 0; - - const T* relative_attention_bias = nullptr; - int relative_attention_bias_stride = 0; - // The slope per head of linear position bias to attention score (H). - const T* linear_bias_slopes = nullptr; - - const T* ia3_key_weights = nullptr; - const T* ia3_value_weights = nullptr; - const int* ia3_tasks = nullptr; - - const float* qkv_scale_out = nullptr; - const float* attention_out_scale = nullptr; - int int8_mode = 0; - - const T *rotary_cos = nullptr; - const T *rotary_sin = nullptr; - - const int *nnz_head_idx = nullptr; - int nnz_heads = 0; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - // will need it here till if constexpr in c++17 - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; -}; - -template -using Masked_multihead_attention_params = Multihead_attention_params; - -template -using Cross_multihead_attention_params = Multihead_attention_params; - -template -struct outputCrossAttentionParam { - // max decoder output length - int max_decoder_seq_len = 0; - T* cross_attention_out = nullptr; - bool is_return_cross_attentions = false; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -void masked_multihead_attention(const 
Masked_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp deleted file mode 100644 index 2ae1b2425b8..00000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp +++ /dev/null @@ -1,1619 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "decoder_masked_multihead_attention.h" -#include "decoder_masked_multihead_attention_utils.h" -#include "cuda_bf16_wrapper.h" -#include "cuda_bf16_fallbacks.cuh" -#include -#include -#include - -// #define MMHA_USE_HMMA_FOR_REDUCTION - -// Below are knobs to extend FP32 accumulation for higher FP16 accuracy - -// Does not seem to affect the accuracy that much -#define MMHA_USE_FP32_ACUM_FOR_FMA - -// Seems to slightly improve the accuracy -#define MMHA_USE_FP32_ACUM_FOR_OUT - -#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) - // Does not seem to improve the accuracy - //#define MMHA_USE_FP32_ACUM_FOR_LOGITS -#endif - -namespace mmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. -// -// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use -// 64, 128 and 256 threads per block. -// -// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to -// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The -// cache buffer helps with memory accesses and contains keys with bias. -// -// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and -// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. 
The -// values for x are chosen to create chunks of 16 bytes. -// -// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs -// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At -// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an -// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. -// -// After that loop, a parallel softmax is computed across the different Q * K^T values stored in -// shared memory. -// -// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many -// timesteps are computed by loop iteration. As with the keys, the values are read from a cache -// except for the current timestep. The layout of the cache buffer for the values is much simpler -// as it is [B, H, L, Dh]. -// - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_vec_ { -}; - -template<> -struct Qk_vec_ { - using Type = float; -}; -template<> -struct Qk_vec_ { - using Type = float2; -}; -template<> -struct Qk_vec_ { - using Type = float4; -}; -template<> -struct Qk_vec_ { - using Type = float4; -}; -template<> -struct Qk_vec_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_ { - using Type = uint2; -}; -template<> -struct Qk_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct Qk_vec_<__nv_bfloat16, 32> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 64> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 128> { - using Type = bf16_4_t; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 256> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_ { -}; - -template<> -struct K_vec_ { - using Type = float; -}; -template<> -struct K_vec_ { - using Type = float2; -}; -template<> -struct K_vec_ { - using Type = float4; -}; -template<> -struct K_vec_ { - using Type = uint32_t; -}; -template<> -struct K_vec_ { - using Type = uint2; -}; -template<> -struct K_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct K_vec_<__nv_bfloat16, 4> { - using Type = __nv_bfloat162; -}; -template<> -struct K_vec_<__nv_bfloat16, 2> { - using Type = bf16_4_t; -}; -template<> -struct K_vec_<__nv_bfloat16, 1> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct V_vec_ { -}; - -template<> -struct V_vec_ { - using Type = float; -}; -template<> -struct V_vec_ { - using Type = float2; -}; -template<> -struct V_vec_ { - using Type = float4; -}; -template<> -struct V_vec_ { - using Type = uint32_t; -}; -template<> -struct V_vec_ { - using Type = uint2; -}; -template<> -struct V_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_<__nv_bfloat16, 2> { - using Type = __nv_bfloat162; -}; -template<> -struct V_vec_<__nv_bfloat16, 4> { - using Type = bf16_4_t; -}; -template<> -struct V_vec_<__nv_bfloat16, 8> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA -template -struct Qk_vec_acum_fp32_ { -}; - 
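Given the cache layout spelled out in the comment above ([B, H, Dh/x, L, x] for keys, [B, H, L, Dh] for values, with x = 16 / sizeof(T)), the offset arithmetic the kernels rely on can be checked with a small host-side sketch. The helper names here are hypothetical, not from the deleted file.

#include <cstddef>
#include <cstdio>

// K cache: [B, H, Dh/x, L, x] with x = 16 / sizeof(T) (8 for FP16, 4 for FP32).
constexpr size_t k_cache_offset(size_t b, size_t h, size_t d, size_t t,
                                size_t H, size_t Dh, size_t L, size_t x) {
    return (((b * H + h) * (Dh / x) + d / x) * L + t) * x + d % x;
}
// V cache: plain [B, H, L, Dh].
constexpr size_t v_cache_offset(size_t b, size_t h, size_t d, size_t t,
                                size_t H, size_t Dh, size_t L) {
    return ((b * H + h) * L + t) * Dh + d;
}

int main() {
    // FP16 keys (x = 8): the 8 elements of one 16-byte chunk are contiguous for a given
    // timestep, and the same chunk of the next timestep sits exactly x elements later.
    printf("K d=32,t=5: %zu  K d=33,t=5: %zu  K d=32,t=6: %zu\n",
           k_cache_offset(1, 2, 32, 5, 16, 64, 1024, 8),
           k_cache_offset(1, 2, 33, 5, 16, 64, 1024, 8),
           k_cache_offset(1, 2, 32, 6, 16, 64, 1024, 8));
    printf("V d=32,t=5: %zu\n", v_cache_offset(1, 2, 32, 5, 16, 64, 1024));
    return 0;
}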
-template<> -struct Qk_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float4; -}; -// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_acum_fp32_ { -}; - -template<> -struct K_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct K_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT -template -struct V_vec_acum_fp32_ { -}; - -template<> -struct V_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif // ENABLE_BF16 -#endif -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) -{ -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = K_vec; -#endif - // Compute the parallel products for Q*K^T (treat vector lanes separately). - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } - - // Finalize the reduction across lanes. 
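The *_acum_fp32_ traits above swap the packed fp16/bf16 vector types for float counterparts when the MMHA_USE_FP32_ACUM_FOR_* knobs are defined. A standalone sketch of why the fp32 accumulator matters; the demo kernel is hypothetical and assumes an architecture with native fp16 math (sm_53 or newer).

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

// Accumulate the same dot product twice: products are identical, only the accumulator differs.
__global__ void dot_both(const __half* a, const __half* b, int n,
                         float* out_fp32, float* out_fp16) {
    float  acc32 = 0.f;
    __half acc16 = __float2half(0.f);
    for (int i = 0; i < n; ++i) {
        acc32 = fmaf(__half2float(a[i]), __half2float(b[i]), acc32); // fp32 accumulation
        acc16 = __hfma(a[i], b[i], acc16);                           // fp16 accumulation
    }
    *out_fp32 = acc32;
    *out_fp16 = __half2float(acc16);
}

int main() {
    const int n = 4096;
    __half *a, *b;
    float *r32, *r16;
    cudaMallocManaged(&a, n * sizeof(__half));
    cudaMallocManaged(&b, n * sizeof(__half));
    cudaMallocManaged(&r32, sizeof(float));
    cudaMallocManaged(&r16, sizeof(float));
    for (int i = 0; i < n; ++i) { a[i] = __float2half(0.01f); b[i] = __float2half(0.01f); }
    dot_both<<<1, 1>>>(a, b, n, r32, r16);
    cudaDeviceSynchronize();
    // Once the running sum dwarfs each ~1e-4 product, the fp16 accumulator drops low-order bits.
    printf("fp32 accum: %f   fp16 accum: %f\n", *r32, *r16);
    return 0;
}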
- float qk = sum(qk_vec); -#pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) - { - return qk_dot_(q, k); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) -{ - float4 c; - float zero = 0.f; - asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" - " {%0, %1, %2, %3}, \n" - " {%4, %5}, \n" - " {%6}, \n" - " {%7, %7, %7, %7}; \n" - - : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = uint32_t; -#endif - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - uint32_t qk_vec_ = float2_to_half2(qk_vec); - return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; -#else - return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; -#endif -#else - return 0.f; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Qk_dot { - template - static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) - { -#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) - return qk_hmma_dot_(q, k); -#else - return qk_dot_<4>(q, k); -#endif // defined MMHA_USE_HMMA_FOR_REDUCTION - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float block_sum(float* red_smem, float sum) -{ - - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - -// Compute the sum per warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } - -// Parallel reduction inside the warp. -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Broadcast to other threads. 
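block_sum above is the standard two-level reduction: an intra-warp butterfly with __shfl_xor_sync, a shared-memory hand-off by warp leaders, and a final butterfly over the warp partials. A self-contained usage sketch; the kernel and names are illustrative, not the deleted code.

#include <cuda_runtime.h>
#include <cstdio>

template <int WARPS_PER_BLOCK, int WARP_SIZE = 32>
__device__ float block_sum_sketch(float* red_smem, float sum) {
    const int warp = threadIdx.x / WARP_SIZE;
    const int lane = threadIdx.x % WARP_SIZE;
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2)        // intra-warp butterfly
        sum += __shfl_xor_sync(0xffffffffu, sum, mask);
    if (lane == 0) red_smem[warp] = sum;                        // warp leaders publish
    __syncthreads();
    sum = (lane < WARPS_PER_BLOCK) ? red_smem[lane] : 0.f;      // reload warp partials
    for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2)  // reduce the partials
        sum += __shfl_xor_sync(0xffffffffu, sum, mask);
    return __shfl_sync(0xffffffffu, sum, 0);                    // broadcast lane 0's result
}

__global__ void sum_thread_ids(float* out) {
    __shared__ float red_smem[4];                               // one slot per warp
    float total = block_sum_sketch<4>(red_smem, (float)threadIdx.x);
    if (threadIdx.x == 0) *out = total;
}

int main() {
    float* out;
    cudaMallocManaged(&out, sizeof(float));
    sum_thread_ids<<<1, 128>>>(out);
    cudaDeviceSynchronize();
    printf("%.0f\n", *out);   // 0 + 1 + ... + 127 = 8128
    return 0;
}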
- return __shfl_sync(uint32_t(-1), sum, 0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float& dst, float src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint16_t& dst, float src) -{ - dst = float_to_half(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint32_t& dst, float2 src) -{ - dst = float2_to_half2(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) -{ - dst = __float2bfloat16(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst = __float22bfloat162_rn(src); -#else - dst = __floats2bfloat162_rn(src.x, src.y); -#endif -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, Float4_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, float4 src) -{ - convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint4& dst, Float8_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) -{ - convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); - dst.z = __float22bfloat162_rn(src.z); - dst.w = __float22bfloat162_rn(src.w); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); - dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); - dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); -#endif -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
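The convert_from_float overloads above pack fp32 accumulators back into the storage type two lanes at a time; the deleted file does the conversion with inline PTX (cvt.rn.f16x2.f32), but the stock intrinsics express the same round-to-nearest packing. Illustrative sketch only.

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

__global__ void pack_roundtrip(float2 in, float2* out) {
    __half2 packed = __float22half2_rn(in);   // two fp32 values -> one 32-bit half2
    *out = __half22float2(packed);            // unpack again to see the fp16 rounding
}

int main() {
    float2* out;
    cudaMallocManaged(&out, sizeof(float2));
    pack_roundtrip<<<1, 1>>>(make_float2(0.1f, -3.14159f), out);
    cudaDeviceSynchronize();
    printf("%.6f %.6f\n", out->x, out->y);    // values after half-precision rounding
    return 0;
}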
-inline __device__ void convert_from_float(float2& dst, float2 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float4& dst, float4 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(float4 u) -{ - return u.x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(uint4 u) -{ - float2 tmp = half2_to_float2(u.x); - return tmp.x; -} - -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float cast_to_float(float u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(float2 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 cast_to_float(float4 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(Float4_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(Float8_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(uint32_t u) -{ - return half2_to_float2(u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(uint2 u) -{ - Float4_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - return tmp; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(uint4 u) -{ - Float8_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - tmp.z = half2_to_float2(u.z); - tmp.w = half2_to_float2(u.w); - return tmp; -} - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float float_from_int8(int8_t u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 float_from_int8(int16_t u) -{ - union { - int16_t int16; - int8_t int8[2]; - }; - int16 = u; - return make_float2(int8[0], int8[1]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 float_from_int8(int32_t u) -{ - union { - int32_t int32; - int8_t int8[4]; - }; - int32 = u; - return make_float4(int8[0], int8[1], int8[2], int8[3]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// clang-format off -inline __device__ Float8_ float_from_int8(int64_t u) -{ - union { - int64_t int64; - int16_t int16[4]; - }; - int64 = u; - return Float8_ {float_from_int8(int16[0]), - float_from_int8(int16[1]), - float_from_int8(int16[2]), - float_from_int8(int16[3])}; -} -// clang-format on - 
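float_from_int8 above reinterprets a packed 16/32/64-bit integer as 2/4/8 int8 lanes and widens each to float; in the kernel this pairs with params.qkv_scale_out to dequantize int8 QKV inputs (int8_mode == 2). A host-side restatement of the idea, assuming a little-endian host and an illustrative scale value.

#include <cstdint>
#include <cstdio>
#include <cstring>

// The device code uses a union; memcpy expresses the same reinterpretation on the host.
static void float4_from_int8(int32_t packed, float out[4]) {
    int8_t lanes[4];
    std::memcpy(lanes, &packed, sizeof(lanes));
    for (int i = 0; i < 4; ++i) out[i] = (float)lanes[i];
}

int main() {
    const float   scale  = 0.05f;        // stand-in for a qkv_scale_out entry
    const int32_t packed = 0x7F01FE80;   // little-endian lanes: -128, -2, 1, 127
    float deq[4];
    float4_from_int8(packed, deq);
    for (int i = 0; i < 4; ++i) printf("%g ", deq[i] * scale);
    printf("\n");                        // -6.4 -0.1 0.05 6.35
    return 0;
}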
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int8_t cast_to_int8(float val) -{ - union { - int8_t int8[2]; - int16_t int16; - }; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); - return int8[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int32_t cast_to_int8(float4 val) -{ - union { - int8_t int8[4]; - int32_t int32; - }; - int8[0] = cast_to_int8(val.x); - int8[1] = cast_to_int8(val.y); - int8[2] = cast_to_int8(val.z); - int8[3] = cast_to_int8(val.w); - return int32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int64_t cast_to_int8(Float8_ val) -{ - union { - int8_t int8[8]; - int64_t int64; - }; - int8[0] = cast_to_int8(val.x.x); - int8[1] = cast_to_int8(val.x.y); - int8[2] = cast_to_int8(val.y.x); - int8[3] = cast_to_int8(val.y.y); - int8[4] = cast_to_int8(val.z.x); - int8[5] = cast_to_int8(val.z.y); - int8[6] = cast_to_int8(val.w.x); - int8[7] = cast_to_int8(val.w.y); - return int64; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ __host__ T div_up(T m, T n) -{ - return (m + n - 1) / n; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline size_t smem_size_in_bytes(const Multihead_attention_params& params, - int threads_per_value, - int threads_per_block) -{ - // The amount of shared memory needed to store the Q*K^T values in float. - const int max_timesteps = min(params.timestep, params.memory_max_len); - size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; - - // The extra memory needed if we are not using floats for the final logits. - size_t logits_sz = 0; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(T) != 4) { - // TDOD - logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) : - div_up(max_timesteps + 1, 4) * 4 * sizeof(T); - } -#endif - - // The total size needed during softmax. - size_t softmax_sz = qk_sz + logits_sz; - - // The number of partial rows to reduce in the final reduction. - int rows_per_red = threads_per_block / threads_per_value; - // The amount of storage needed to finalize the outputs. - size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2; - - size_t transpose_rotary_size = 0; - if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { - transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T); - } - - // The max. - return max(max(softmax_sz, red_sz), transpose_rotary_size); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ constexpr uint32_t shfl_mask(int threads) -{ - return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The type of the inputs. Supported types: float and half. - typename T, - // The hidden dimension per head. - int Dh, - int Dh_MAX, - // The number of threads per key. - int THREADS_PER_KEY, - // The number of threads per value. - int THREADS_PER_VALUE, - // The number of threads in a threadblock. 
- int THREADS_PER_BLOCK, - bool DO_CROSS_ATTENTION> -__global__ void masked_multihead_attention_kernel(Multihead_attention_params params) -{ - - // Make sure the hidden dimension per head is a multiple of the number of threads per key. - static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); - // Make sure the hidden dimension per head is a multiple of the number of threads per value. - static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); - - // The size of a warp. - constexpr int WARP_SIZE = 32; - // The number of warps in a threadblock. - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; - - // Use smem_size_in_bytes (above) to determine the amount of shared memory. - extern __shared__ char smem_[]; - - // The shared memory for the Q*K^T values and partial logits in softmax. - float* qk_smem = reinterpret_cast(smem_); - - // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. - char* logits_smem_ = smem_; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(T) != 4) { - // TODO - change to tlength - const int max_timesteps = min(params.timestep, params.memory_max_len); - logits_smem_ += - (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; - } - T* logits_smem = reinterpret_cast(logits_smem_); -#else - float* logits_smem = reinterpret_cast(logits_smem_); -#endif - - // The shared memory to do the final reduction for the output values. Reuse qk_smem. - T* out_smem = reinterpret_cast(smem_); - - // The shared memory buffers for the block-wide reductions. One for max, one for sum. - __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec = typename Qk_vec_::Type; - - // Use alignment for safely casting the shared buffers as Qk_vec. - // Shared memory to store Q inputs. - __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; - - // This is one of the reasons we should have a separate kernel for cross attention - __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec = typename Qk_vec_::Type; - // The number of elements per vector. - constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); - // We will use block wide reduction if needed - // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); - // The number of vectors per warp. - constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; - - // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread - // owns x elements, we have to decompose the linear index into chunks of x values and the posi- - // tion of the thread in that chunk. - - // The number of elements in a chunk of 16B (that's the x in the above formula). - constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); - // The number of K vectors in 16B. - constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); - - // The batch/beam idx - const int bi = blockIdx.y; - if (params.finished != nullptr && params.finished[bi] == true) { - return; - } - // The beam idx - const int beami = bi % params.beam_width; - // The "beam-aware" batch idx - const int bbi = bi / params.beam_width; - // The head. - // const int hi = blockIdx.x; - const int hi = params.nnz_head_idx == nullptr ? 
blockIdx.x : params.nnz_head_idx[blockIdx.x]; - const int hi_kv = hi / params.num_heads_q_kv_ratio; - // Combine the batch and the head indices. - const int bhi = bi * params.num_heads + hi; - const int bhi_kv = bi * params.num_heads_kv + hi_kv; - // Combine the "beam-aware" batch idx and the head indices. - const int bbhi = bbi * params.beam_width * params.num_heads_kv + hi_kv; - // The thread in the block. - const int tidx = threadIdx.x; - - const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0); - - // While doing the product Q*K^T for the different keys we track the max. - float qk_max = -FLT_MAX; - - float qk = 0.0F; - - int q_base_offset = (params.stride_q == 0) ? bhi * Dh : bi * params.stride_q + hi * Dh; - int k_base_offset = (params.stride_k == 0) ? bhi_kv * Dh : bi * params.stride_k + hi_kv * Dh; - int v_base_offset = (params.stride_v == 0) ? bhi_kv * Dh : bi * params.stride_v + hi_kv * Dh; - - const size_t bi_seq_len_offset = bi * params.memory_max_len; - - // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 : - (params.length_per_sample == nullptr) ? - params.timestep : - params.length_per_sample[bi] + params.max_prefix_prompt_length; - const int first_step = max(0, tlength + 1 - params.memory_max_len); - const int tlength_circ = tlength % params.memory_max_len; - - // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. - const bool is_masked = tidx >= QK_VECS_PER_WARP; - - // The offset in the Q and K buffer also accounts for the batch. - int q_offset = q_base_offset + tidx * QK_VEC_SIZE; - int k_offset = k_base_offset + tidx * QK_VEC_SIZE; - // The offset in the bias buffer. - int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; - int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE; - - const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; - const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; - - // Trigger the loads from the Q and K buffers. - Qk_vec q; - zero(q); - if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto q_scaling = params.qkv_scale_out[0]; - const auto q_quant = - *reinterpret_cast(&reinterpret_cast(params.q)[q_offset]); - - convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); - } - else { - q = *reinterpret_cast(¶ms.q[q_offset]); - } - } - - Qk_vec k; - zero(k); - if (DO_CROSS_ATTENTION) { - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength * QK_ELTS_IN_16B + ci; - k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? 
- *reinterpret_cast(¶ms.k_cache[offset]) : - k; - } - else { - if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto k_scaling = params.qkv_scale_out[1]; - const auto k_quant = - *reinterpret_cast(&reinterpret_cast(params.k)[k_offset]); - - convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); - } - else { - k = *reinterpret_cast(¶ms.k[k_offset]); - } - } - } - - // Trigger the loads from the Q and K bias buffers. - Qk_vec q_bias; - zero(q_bias); - q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? - *reinterpret_cast(¶ms.q_bias[q_bias_offset]) : - q_bias; - - Qk_vec k_bias; - zero(k_bias); - if (handle_kv) { - k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? - *reinterpret_cast(¶ms.k_bias[k_bias_offset]) : - k_bias; - } - - // Computes the Q/K values with bias. - q = add(q, q_bias); - if (handle_kv) { - k = add(k, k_bias); - } - if (do_ia3 && !is_masked) { - k = mul( - k, - *reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE])); - } - - // Padded len - const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; - if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { - if (handle_kv) { - if (params.rotary_cos == nullptr) { - apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } else { - apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - } - else { - if (params.rotary_cos == nullptr) { - apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } else { - apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - } - } - else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { - const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; - - T* q_smem = reinterpret_cast(smem_); - T* k_smem = q_smem + params.rotary_embedding_dim; - - const int half_rotary_dim = params.rotary_embedding_dim / 2; - const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; - const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; - const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts - - assert(half_rotary_dim % QK_VEC_SIZE == 0); - - if (do_rotary) { - *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; - - if (handle_kv) { - *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; - } - } - - __syncthreads(); - - const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; - constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? 
QK_VEC_SIZE / 2 : 1; - if (do_rotary) { - mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); - - if (handle_kv) { - mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); - - if (params.rotary_cos == nullptr) { - mmha::apply_rotary_embedding( - q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } else { - mmha::apply_rotary_embedding( - q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - - mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); - } - else { - if (params.rotary_cos == nullptr) { - mmha::apply_rotary_embedding( - q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base); - } else { - mmha::apply_rotary_embedding( - q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - } - mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); - } - - __syncthreads(); - - if (do_rotary) { - q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); - if (handle_kv) { - k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); - } - } - - __syncthreads(); - } - - if (!is_masked) { - // Store the Q values to shared memory. - *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; - - // Store Dh values of k_bias into smem, since will need to add later - // if params.timestep == 0 - if (DO_CROSS_ATTENTION && params.timestep == 0) { - *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; - } - - // Write the K values to the global memory cache. - // - // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory - // system. We designed it this way as it allows much better memory loads (and there are many - // more loads) + the stores are really "write and forget" since we won't need the ack before - // the end of the kernel. There's plenty of time for the transactions to complete. - - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength_circ * QK_ELTS_IN_16B + ci; - - if (handle_kv && hi % params.num_heads_q_kv_ratio == 0) { - // Trigger the stores to global memory. - if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset]) = k; - } - } - - // Compute \sum_i Q[i] * K^T[i] for the current timestep. -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; -#else - using Qk_vec_acum = Qk_vec; -#endif - qk = dot(q, k); - if (QK_VECS_PER_WARP <= WARP_SIZE) { -#pragma unroll - for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); - } - } - } - - if (QK_VECS_PER_WARP > WARP_SIZE) { - constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; - qk = block_sum(&red_smem[WARPS_PER_RED], qk); - } - - // Store that value in shared memory. Keep the Q*K^T value in register for softmax. 
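The rotary paths above rotate q (and k, when this kernel owns the KV write) before the current-timestep Q*K product, using an angle that depends on the timestep and the element-pair index. A minimal host-side sketch of the interleaved (non-neox) variant; names are hypothetical and the real code works on packed vectors in registers.

#include <cmath>
#include <cstdio>

// Pair (2i, 2i+1) is rotated by angle t * base^(-2i / rot_dim).
static void rope_inplace(float* q, int rot_dim, int t, float base) {
    for (int i = 0; 2 * i + 1 < rot_dim; ++i) {
        const float inv_freq = powf(base, -2.0f * i / rot_dim);
        const float c = cosf(t * inv_freq), s = sinf(t * inv_freq);
        const float x = q[2 * i], y = q[2 * i + 1];
        q[2 * i]     = x * c - y * s;
        q[2 * i + 1] = x * s + y * c;
    }
}

int main() {
    float q[8] = {1, 0, 1, 0, 1, 0, 1, 0};
    rope_inplace(q, 8, /*timestep=*/5, /*rotary_base=*/10000.f);
    for (float v : q) printf("%.4f ", v);   // the same rotation is applied to K before caching
    printf("\n");
    return 0;
}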
- if (tidx == 0) { - // Normalize qk. - qk *= params.inv_sqrt_dh; - if (params.relative_attention_bias != nullptr) { - qk = add(qk, - params.relative_attention_bias[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + (tlength - padd_len) * params.relative_attention_bias_stride - + (tlength - padd_len)]); - } - // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0. - - qk_max = qk; - qk_smem[tlength - first_step] = qk; - // qk_smem[params.timestep] = qk; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The type of queries and keys for the math in the Q*K^T product. - using K_vec = typename K_vec_::Type; - // The number of elements per vector. - constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); - // The number of elements per thread. - constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; - // The number of vectors per thread. - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; - - // The position the first key loaded by each thread from the cache buffer (for this B * H). - int ko = tidx / THREADS_PER_KEY; - // The position of the thread in the chunk of keys. - int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; - - static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); - - // Load the Q values from shared memory. The values are reused during the loop on K. - K_vec q_vec[K_VECS_PER_THREAD]; -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - - K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1]; - if (DO_CROSS_ATTENTION && params.timestep == 0) { -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - k_bias_vec[ii] = *reinterpret_cast(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - } - - // The number of timesteps loaded per iteration. - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - // The number of keys per warp. - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - - // The base pointer for the key in the cache buffer. - T* k_cache = ¶ms.k_cache[bhi_kv * params.memory_max_len * Dh + ki]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; - - // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). - // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; - int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; - - // prefix prompt length if has - const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi]; - - // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. - const bool has_beams = params.cache_indir != nullptr; - const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; - - for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { - const int ti_circ = ti % params.memory_max_len; - - // The keys loaded from the key cache. 
- K_vec k[K_VECS_PER_THREAD]; - K_vec k_vec_zero; - zero(k_vec_zero); -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.memory_max_len + ti_circ; - // if( ti < params.timestep ) { - const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); - if (ti < tlength) { - if (!within_bounds) { - k[ii] = k_vec_zero; - } - else { - if (has_beams) { - const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; - k[ii] = *reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]); - } - else { - k[ii] = *reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]); - } - } - // add bias and update k_cache - if (DO_CROSS_ATTENTION && params.timestep == 0) { - k[ii] = add(k[ii], k_bias_vec[ii]); - - if (do_ia3) { - k[ii] = mul( - k[ii], - *reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki - + ii * THREADS_PER_KEY * K_VEC_SIZE])); - } - - if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) { - *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii]; - } - } - } - } - - // Perform the dot product and normalize qk. - // - // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! - float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; - bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; - - // Store the product to shared memory. There's one qk value per timestep. Update the max. - // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { - if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - if (params.relative_attention_bias != nullptr) { - qk = add(qk, - params.relative_attention_bias[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + tlength * params.relative_attention_bias_stride + ti]); - } - if (params.linear_bias_slopes != nullptr) { - // Apply the linear position bias: (ki - qi) * slope[hi]. - // The padding token locates between the input context and the generated tokens. - // We need to remove the number of padding tokens in the distance computation. - // ti : 0 1 2 3 4 5 6 7 8 9(tlength) - // token: i i i i p p p o o o where i=input, p=pad, o=output. - // e.g. ti = 2, dist = (9 - 3) - 2 = 4. - int max_context_length = params.max_prefix_prompt_length + params.max_input_length; - float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength; - - qk += mul(params.linear_bias_slopes[hi], dist); - } - qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); - qk_smem[ti - first_step] = qk; - } - } - -// Perform the final reduction to compute the max inside each warp. -// -// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the -// group so it's not needed to run the reduction inside the group (again). -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Decompose the thread index into warp and lane. - const int warp = tidx / WARP_SIZE; - const int lane = tidx % WARP_SIZE; - - // The warp leader writes the max to shared memory. - if (lane == 0) { - red_smem[warp] = qk_max; - } - - // Make sure the products are in shared memory. - __syncthreads(); - - // The warps finalize the reduction. - qk_max = lane < WARPS_PER_BLOCK ? 
red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Broadcast to all the threads in the warp. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // Compute the logits and start the sum. - float sum = 0.f; - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; - float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); - sum += logit; - qk_smem[ti - first_step] = logit; - } - - // Compute the sum. - sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); - - // Normalize the logits. - float inv_sum = __fdividef(1.f, sum + 1.e-6f); - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - const size_t cross_attention_out_offset = - params.is_return_cross_attentions ? - bhi * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len : - 0; - for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - float logit = qk_smem[ti - first_step] * inv_sum; - if (params.is_return_cross_attentions) { - params.cross_attention_out[cross_attention_out_offset + ti] = logit; - } - convert_from_float(logits_smem[ti - first_step], logit); - } - - // Put Values part below so we leverage __syncthreads - // from the previous step - - // The number of elements per vector. - constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; - // A vector of V elements for the current timestep. - using V_vec = typename V_vec_::Type; - - // The value computed by this thread. - int vo = tidx / THREADS_PER_VALUE; - // The hidden dimensions computed by this particular thread. - int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - - // The base pointer for the value in the cache buffer. - T* v_cache = ¶ms.v_cache[bhi_kv * params.memory_max_len * Dh + vi]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; - - // The number of values processed per iteration of the loop. - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - - // One group of threads computes the product(s) for the current timestep. - V_vec v_bias; - zero(v_bias); - // if( vo == params.timestep % V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - if (handle_kv) { - if (vo == tlength % V_PER_ITER) { - // Trigger the loads from the V bias buffer. - if (params.v_bias != nullptr) { - v_bias = *reinterpret_cast(¶ms.v_bias[hi_kv * Dh + vi]); - } - if (DO_CROSS_ATTENTION) { - *reinterpret_cast(&bias_smem[vi]) = v_bias; - } - } - } - } - - // From previous, before values, step - // Also make sure the logits are in shared memory. - __syncthreads(); - - // Values continued -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - using V_vec_acum = typename V_vec_acum_fp32_::Type; -#else - using V_vec_acum = V_vec; -#endif - // The partial outputs computed by each thread. - V_vec_acum out; - zero(out); - - // Loop over the timesteps to compute the partial outputs. 
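Before the value loop, the Q*K^T scores computed above are turned into probabilities with a max-shifted exponential and the 1e-6-guarded normalizer. A host-side restatement of that softmax step, for illustration only.

#include <cmath>
#include <cstdio>

// Shift by the running max so the exponential cannot overflow, then normalize by the sum
// (with the same 1e-6 guard the kernel uses).
static void softmax_inplace(float* qk, int n) {
    float qk_max = qk[0];
    for (int i = 1; i < n; ++i) qk_max = qk[i] > qk_max ? qk[i] : qk_max;
    float sum = 0.f;
    for (int i = 0; i < n; ++i) { qk[i] = expf(qk[i] - qk_max); sum += qk[i]; }
    const float inv_sum = 1.f / (sum + 1.e-6f);
    for (int i = 0; i < n; ++i) qk[i] *= inv_sum;
}

int main() {
    float logits[4] = {85.f, 88.f, 90.f, 40.f};   // expf(90.f) alone would overflow FP32
    softmax_inplace(logits, 4);
    for (float p : logits) printf("%.4f ", p);    // ~0.0059 0.1185 0.8756 0.0000
    printf("\n");
    return 0;
}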
- // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { - const int ti_circ = ti % params.memory_max_len; - - // Fetch offset based on cache_indir when beam sampling - const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; - const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh; - // Load the values from the cache. - V_vec v = *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh]); - if (DO_CROSS_ATTENTION && params.timestep == 0) { - v = add(v, *reinterpret_cast(&bias_smem[vi])); - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - *reinterpret_cast(&v_cache[ti * Dh]) = v; - } - // Load the logits from shared memory. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - float logit = logits_smem[ti - first_step]; - out = fma(logit, cast_to_float(v), out); -#else - T logit = logits_smem[ti - first_step]; - - // Update the partial sums. - out = fma(logit, v, out); -#endif - } - } - - // One group of threads computes the product(s) for the current timestep. - // if( vo == params.timestep % V_PER_ITER ) { - if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { - - V_vec v; - if (DO_CROSS_ATTENTION) { - v = *reinterpret_cast(&v_cache[tlength * Dh]); - } - else { - // Trigger the loads from the V buffer. - const auto v_offset = v_base_offset + vi; - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto v_scaling = params.qkv_scale_out[2]; - const auto v_quant = - *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); - - convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); - } - else { - v = *reinterpret_cast(¶ms.v[v_offset]); - } - // Trigger the loads from the V bias buffer. - // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); - } - - // Compute the V values with bias. - if (handle_kv) { - v = add(v, v_bias); - - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - - // Store the values with bias back to global memory in the cache for V. - if (hi % params.num_heads_q_kv_ratio == 0) { - //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[tlength_circ * Dh]) = v; - } - } - - // Initialize the output value with the current timestep. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - // out = fma(logits_smem[params.timestep], cast_to_float(v), out); - out = fma(logits_smem[tlength - first_step], cast_to_float(v), out); -#else - // out = fma(logits_smem[params.timestep], v, out); - out = fma(logits_smem[tlength - first_step], v, out); -#endif - } - - // Make sure we can start writing to shared memory. - __syncthreads(); - - // Run the final reduction amongst the different groups computing different partial outputs. - if (Dh == Dh_MAX || vi < Dh) { -#pragma unroll - for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { - - // The midpoint in the number of active groups. - int midpoint = active_groups / 2; - - // The upper part of active threads store to shared memory. 
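Both the key loop and the value loop above fold absolute timesteps into cache slots modulo params.memory_max_len, so only the most recent memory_max_len steps stay resident and first_step marks where attention starts. A small sketch of that ring-buffer mapping, with illustrative sizes.

#include <algorithm>
#include <cstdio>

int main() {
    const int memory_max_len = 8;
    const int tlength        = 11;   // current absolute timestep
    const int first_step     = std::max(0, tlength + 1 - memory_max_len);
    printf("attend over absolute steps [%d, %d]\n", first_step, tlength);
    for (int ti = first_step; ti <= tlength; ++ti)
        printf("  step %2d -> cache slot %d\n", ti, ti % memory_max_len);
    return 0;
}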
- if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out); -#else - *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; -#endif - } - __syncthreads(); - - // The bottom warps update their values. - if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { - out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); - } - __syncthreads(); - } - } - - // Output the final values. - if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - out = mul(*params.attention_out_scale, out); - *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) = - cast_to_int8(out); - } - else { - convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); - } -#else - // TODO: support int8_mode? - *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; -#endif - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace mmha - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h b/csrc/ft_attention/decoder_masked_multihead_attention_utils.h deleted file mode 100644 index 98875aba9b8..00000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h +++ /dev/null @@ -1,2017 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include "cuda_bf16_fallbacks.cuh" -#include - -using namespace fastertransformer; - -namespace mmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float8_ { - float2 x; - float2 y; - float2 z; - float2 w; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float4_ { - float2 x; - float2 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -struct bf16_4_t { - __nv_bfloat162 x; - __nv_bfloat162 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct bf16_8_t { - __nv_bfloat162 x; - __nv_bfloat162 y; - __nv_bfloat162 z; - __nv_bfloat162 w; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct num_elems; -template<> -struct num_elems { - static constexpr int value = 1; -}; -template<> -struct num_elems { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; - -template<> -struct num_elems { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; - -#ifdef ENABLE_BF16 -template<> -struct num_elems<__nv_bfloat162> { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct packed_type; -template -struct packed_type { - using type = T; -}; -template<> -struct packed_type { - using type = int16_t; -}; -template<> -struct packed_type { - using type = int32_t; -}; -template<> -struct packed_type { - using type = int64_t; -}; - -template<> -struct packed_type { - using type = float2; -}; -template<> -struct packed_type { - using type = float4; -}; -template<> -struct packed_type { - using type = Float8_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float add(float a, float b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(float2 a, float2 b) -{ - float2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 add(float4 a, float4 b) -{ - float4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hadd2(a, b); -} - 
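Most of the arithmetic defined below treats uint16_t/uint32_t operands as one or two packed halves and goes through f16x2 PTX instructions. The same two-lane addition expressed with the stock __half2 intrinsics, as an illustrative sketch (assumes sm_53 or newer):

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

__global__ void add_two_halves(float2* out) {
    __half2 a = __floats2half2_rn(1.5f, -2.0f);
    __half2 b = __floats2half2_rn(0.25f, 4.0f);
    __half2 c = __hadd2(a, b);               // both lanes added by a single instruction
    *out = __half22float2(c);
}

int main() {
    float2* out;
    cudaMallocManaged(&out, sizeof(float2));
    add_two_halves<<<1, 1>>>(out);
    cudaDeviceSynchronize();
    printf("%.2f %.2f\n", out->x, out->y);   // 1.75 2.00
    return 0;
}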
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t add(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t add(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 add(uint2 a, uint2 b) -{ - uint2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 add(uint4 a, uint4 b) -{ - uint4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t float_to_half(float f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better? 
- float zero = 0.f; - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); -#endif - return tmp.u16[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t float2_to_half2(float2 f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); -#endif - return tmp.u32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float half_to_float(uint16_t h) -{ - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 half2_to_float2(uint32_t v) -{ - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); - return make_float2(half_to_float(lo), half_to_float(hi)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float add(float a, uint16_t b) -{ - return a + half_to_float(b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float add(float a, __nv_bfloat16 b) -{ - return a + __bfloat162float(b); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(uint32_t a, float2 fb) -{ - float2 fa = half2_to_float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(uint2 a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(uint4 a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t h0_h0(uint16_t a) -{ - uint32_t b; - asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); - return b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(float a, float b, float c) -{ - return a * b + c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float2 a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ 
float4 fma(float4 a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 fma(float a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) -{ - Float4_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) -{ - Float8_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float2 add(__nv_bfloat162 a, float2 fb) -{ - float2 fa = bf1622float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) -{ - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) -{ - return fma(h0_h0(a), b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) -{ - uint2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) -{ - uint32_t s = h0_h0(a); - uint2 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) -{ - uint4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) -{ - uint32_t s = h0_h0(a); - uint4 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - 
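For reference, a rough functional equivalent of the packed-fp16 FMA path above, written with CUDA half intrinsics instead of raw integer registers and inline PTX. This is a sketch only: it assumes sm_53+ and the kernel name is illustrative, not part of the deleted code.

#include <cuda_fp16.h>

// h0_h0(a) duplicates one fp16 value into both 16-bit lanes of a 32-bit register;
// __half2half2 is the intrinsic spelling of the same broadcast. fma.rn.f16x2 does
// two fused multiply-adds per instruction, which __hfma2 exposes directly.
__global__ void fma_half2_demo(__half2* out, const __half* a, const __half2* b, const __half2* c)
{
    __half2 s = __half2half2(a[0]);
    out[0] = __hfma2(s, b[0], c[0]);
}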
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(uint16_t a, uint16_t b, float fc) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) -{ - return fma(h0_h0(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) -{ - uint32_t s = h0_h0(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) -{ - uint32_t s = h0_h0(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(a, b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(bf162bf162(a), b, c); -} -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) -{ - bf16_4_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) -{ - bf16_8_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t d; - d.x = 
fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) -{ - return __bfloat162float(a) * __bfloat162float(b) + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) -{ - return fma(bf162bf162(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ Acc mul(A a, B b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(float a, float b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float2 a, float2 b) -{ - float2 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float a, float2 b) -{ - float2 c; - c.x = a * b.x; - c.y = a * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float4 a, float4 b) -{ - float4 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - c.z = a.z * b.z; - c.w = a.w * b.w; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float a, float4 b) -{ - float4 c; - c.x = a * b.x; - c.y = a * b.y; - c.z = a * b.z; - c.w = a * b.w; - return c; -} - 
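The generic mul defined above takes the accumulator (return) type as its leading template parameter; since that type never appears in the argument list it cannot be deduced, so call sites spell it out explicitly, which is why the bf16 overloads later in the file write mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(...). A minimal standalone sketch of the same pattern, with names that are ours rather than the header's:

#include <cstdio>

// Acc is deliberately not deducible: the caller chooses how wide the product is kept,
// e.g. multiplying narrow inputs while accumulating in fp32.
template<typename Acc, typename A, typename B>
Acc mul(A a, B b)
{
    return a * b;
}

int main()
{
    // float r = mul(3, 0.5f);                  // would not compile: Acc cannot be deduced
    float r = mul<float, int, float>(3, 0.5f);  // explicit accumulator, as at the header's call sites
    std::printf("%f\n", r);
    return 0;
}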
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(float a, Float8_ b) -{ - Float8_ c; - c.x = make_float2(a * b.x.x, a * b.x.y); - c.y = make_float2(a * b.y.x, a * b.y.y); - c.z = make_float2(a * b.z.x, a * b.z.y); - c.w = make_float2(a * b.w.x, a * b.w.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint16_t mul(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint2 a, uint2 b) -{ - uint2 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - uint2 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint4 a, uint4 b) -{ - uint4 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - c.z = mul(a.z, b.z); - c.w = mul(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - uint4 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - c.z = mul(s, b.z); - c.w = mul(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(uint16_t a, uint16_t b) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(uint16_t a, float b) -{ - return half_to_float(a) * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint32_t a, uint32_t b) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(uint2 a, uint2 b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ 
mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint4 a, uint4 b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -template<> -inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - return __hmul(a, b); -#else - return bf16hmul(a, b); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hmul2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ - float fa = (float)a; - float fb = (float)b; - return 
fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(__nv_bfloat16 a, float b) -{ - return __bfloat162float(a) * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float v) -{ - return v; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float2 v) -{ - return v.x + v.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float4 v) -{ - return v.x + v.y + v.z + v.w; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float sum(__nv_bfloat162 v) -{ - float2 vf = bf1622float2(v); - return vf.x + vf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_4_t v) -{ - return sum(v.x) + sum(v.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_8_t v) -{ - return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint16_t v) -{ - return half_to_float(v); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint32_t v) -{ - float2 tmp = half2_to_float2(v); - return tmp.x + tmp.y; -} - 
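The uint16_t/uint32_t overloads above treat raw integer registers as fp16 storage and only widen to float when accumulating, mirroring float2_to_half2 and half2_to_float2 earlier in the file. A small sketch of the same round trip using CUDA intrinsics rather than inline PTX; illustrative only, and the function names are ours:

#include <cuda_fp16.h>
#include <cstring>

// Pack two floats into the 32-bit register layout used for half2 values
// (float2_to_half2 does the same via cvt.rn.f16x2.f32).
__device__ uint32_t pack_float2_to_u32(float2 f)
{
    __half2 h = __float22half2_rn(f);
    uint32_t u;
    memcpy(&u, &h, sizeof(u));  // reinterpret the 4-byte half2 as raw storage
    return u;
}

// Widen each fp16 lane back to float and reduce, as sum(uint32_t) above does.
__device__ float sum_packed_half2(uint32_t u)
{
    __half2 h;
    memcpy(&h, &u, sizeof(h));
    float2 f = __half22float2(h);
    return f.x + f.y;
}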
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint2 v) -{ - uint32_t c = add(v.x, v.y); - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint4 v) -{ -#if 1 - uint32_t c = add(v.x, v.y); - c = add(c, v.z); - c = add(c, v.w); -#else - uint32_t c = add(v.x, v.y); - uint32_t d = add(v.z, v.w); - c = add(c, d); -#endif - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float4_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float8_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void zero(uint16_t& dst) -{ - dst = uint16_t(0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void zero(T& dst) -{ - constexpr int WORDS = sizeof(T) / 4; - union { - T raw; - uint32_t words[WORDS]; - } tmp; -#pragma unroll - for (int ii = 0; ii < WORDS; ++ii) { - tmp.words[ii] = 0u; - } - dst = tmp.raw; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const int t_step, const float base) -{ - const float pos_idx_inv_freq = t_step / pow(base, zid / (float)rot_embed_dim); - return {cos(pos_idx_inv_freq), sin(pos_idx_inv_freq)}; -} - -inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef) -{ - float2 rot_v; - rot_v.x = coef.x * v.x - coef.y * v.y; - rot_v.y = coef.x * v.y + coef.y * v.x; - return rot_v; -} - -inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef) -{ - float2 fv = half2_to_float2(v); - float2 rot_fv = rotary_embedding_transform(fv, coef); - return float2_to_half2(rot_fv); -} - -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) -{ - float2 fv = bf1622float2(v); - float2 rot_fv = rotary_embedding_transform(fv, coef); - return __floats2bfloat162_rn(rot_fv.x, rot_fv.y); -} -#endif - -inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline 
__device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q_.x = rotary_embedding_transform(q_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q_.y = rotary_embedding_transform(q_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - Float4_& k_ = *reinterpret_cast(&k); - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q_.x = rotary_embedding_transform(q_.x, coef0); - k_.x = rotary_embedding_transform(k_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q_.y = rotary_embedding_transform(q_.y, coef1); - k_.y = rotary_embedding_transform(k_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = 
rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} - -#ifdef ENABLE_BF16 -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, 
coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
-    q.z = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
-    q.w = rotary_embedding_transform(q.w, coef3);
-}
-
-inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
-{
-    if (8 * tid >= rot_embed_dim) {
-        return;
-    }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
-    q.x = rotary_embedding_transform(q.x, coef0);
-    k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
-    q.y = rotary_embedding_transform(q.y, coef1);
-    k.y = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
-    q.z = rotary_embedding_transform(q.z, coef2);
-    k.z = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
-    q.w = rotary_embedding_transform(q.w, coef3);
-    k.w = rotary_embedding_transform(k.w, coef3);
-}
-#endif // ENABLE_BF16
-
-template<typename T>
-inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const T* rotary_cos, const T* rotary_sin)
-{
-    // zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
-    // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
-    return {float(rotary_cos[zid / 2]), float(rotary_sin[zid / 2])};
-}
-
-// fp16 is special because we use uint16_t for reading the data, for backward compatibility.
-template <>
-inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin)
-{
-    // zid is the index of the dimension (0, 2, 4, ..., rotary_dim).
-    // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2.
- return {float(reinterpret_cast(rotary_cos)[zid / 2]), - float(reinterpret_cast(rotary_sin)[zid / 2])}; -} - -inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q_.x = rotary_embedding_transform(q_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q_.y = rotary_embedding_transform(q_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - Float4_& k_ = *reinterpret_cast(&k); - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q_.x = rotary_embedding_transform(q_.x, coef0); - k_.x = rotary_embedding_transform(k_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q_.y = rotary_embedding_transform(q_.y, coef1); - k_.y = rotary_embedding_transform(k_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, 
t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} - -#ifdef ENABLE_BF16 -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, 
rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} -#endif // ENABLE_BF16 - -template -__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); - -template<> -__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch) -{ - return; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u16[0] = smem[transpose_idx]; - tmp.u16[1] = smem[smem_pitch + transpose_idx]; - - vec = tmp.u32; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp_1, tmp_2; - tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - union { - uint2 u32x2; - uint16_t 
u16[4]; - } tmp_3; - tmp_3.u16[0] = tmp_1.u16[0]; - tmp_3.u16[1] = tmp_2.u16[0]; - tmp_3.u16[2] = tmp_1.u16[1]; - tmp_3.u16[3] = tmp_2.u16[1]; - - vec = tmp_3.u32x2; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - uint16_t u16[4]; - } tmp_1, tmp_2; - tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - union { - uint4 u32x4; - uint16_t u16[8]; - } tmp_3; - tmp_3.u16[0] = tmp_1.u16[0]; - tmp_3.u16[1] = tmp_2.u16[0]; - tmp_3.u16[2] = tmp_1.u16[1]; - tmp_3.u16[3] = tmp_2.u16[1]; - tmp_3.u16[4] = tmp_1.u16[2]; - tmp_3.u16[5] = tmp_2.u16[2]; - tmp_3.u16[6] = tmp_1.u16[3]; - tmp_3.u16[7] = tmp_2.u16[3]; - - vec = tmp_3.u32x4; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - __nv_bfloat16 bf16[2]; - } tmp_1, tmp_2; - tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; - vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; -} - -template<> -__device__ __inline__ void -vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - __nv_bfloat16 bf16[4]; - } tmp_1, tmp_2; - tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; - vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; - vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]}; - vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]}; -} -#endif // ENABLE_BF16 - -template<> -__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.z = smem[transpose_idx + 1]; - vec.y = smem[smem_pitch + transpose_idx]; - vec.w = smem[smem_pitch + transpose_idx + 1]; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - half u16[2]; - } tmp; - tmp.u16[0] = smem[transpose_idx]; - tmp.u16[1] = smem[smem_pitch + transpose_idx]; - - vec = tmp.u32; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.y = smem[smem_pitch + transpose_idx]; -} -#endif - -template<> -__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.y = smem[smem_pitch + transpose_idx]; -} - -template -__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); - -template<> -__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch) -{ - return; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - uint16_t u16[4]; - } tmp_1, tmp_2; - - union { - uint4 u32x4; - uint16_t u16[8]; - } tmp_3; - tmp_3.u32x4 = vec; - tmp_1.u16[0] = tmp_3.u16[0]; - tmp_2.u16[0] = tmp_3.u16[1]; - tmp_1.u16[1] = tmp_3.u16[2]; - 
tmp_2.u16[1] = tmp_3.u16[3]; - tmp_1.u16[2] = tmp_3.u16[4]; - tmp_2.u16[2] = tmp_3.u16[5]; - tmp_1.u16[3] = tmp_3.u16[6]; - tmp_2.u16[3] = tmp_3.u16[7]; - - *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u64; - *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u64; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp_1, tmp_2; - - union { - uint2 u32x2; - uint16_t u16[4]; - } tmp_3; - tmp_3.u32x2 = vec; - tmp_1.u16[0] = tmp_3.u16[0]; - tmp_2.u16[0] = tmp_3.u16[1]; - tmp_1.u16[1] = tmp_3.u16[2]; - tmp_2.u16[1] = tmp_3.u16[3]; - - *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u32; - *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u32; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u32 = vec; - - smem[transpose_idx] = tmp.u16[0]; - smem[smem_pitch + transpose_idx] = tmp.u16[1]; -} - -template<> -__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[transpose_idx + 1] = vec.z; - smem[smem_pitch + transpose_idx] = vec.y; - smem[smem_pitch + transpose_idx + 1] = vec.w; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - half u16[2]; - } tmp; - - tmp.u32 = vec; - smem[transpose_idx] = tmp.u16[0]; - smem[smem_pitch + transpose_idx] = tmp.u16[1]; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[smem_pitch + transpose_idx] = vec.y; -} - -template<> -__device__ __inline__ void -write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); -} - -template<> -__device__ __inline__ void -write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); -} -#endif - -template<> -__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[smem_pitch + transpose_idx] = vec.y; -} - -} // namespace mmha diff --git a/csrc/ft_attention/ft_attention.cpp b/csrc/ft_attention/ft_attention.cpp deleted file mode 100644 index 886da9729ba..00000000000 --- a/csrc/ft_attention/ft_attention.cpp +++ /dev/null @@ -1,231 +0,0 @@ -#include -#include "ATen/cuda/CUDAContext.h" -#include - - -#include "decoder_masked_multihead_attention.h" - -#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) 
\
-    if (TYPE == at::ScalarType::Half) {                                      \
-        using scalar_t = at::Half;                                           \
-        __VA_ARGS__();                                                       \
-    } else if (TYPE == at::ScalarType::BFloat16) {                           \
-        using scalar_t = at::BFloat16;                                       \
-        __VA_ARGS__();                                                       \
-    } else if (TYPE == at::ScalarType::Float) {                              \
-        using scalar_t = float;                                              \
-        __VA_ARGS__();                                                       \
-    } else {                                                                 \
-        AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
-    }
-
-template<typename T>
-void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
-                                const cudaStream_t& stream);
-
-template<typename T>
-void cross_multihead_attention(const Masked_multihead_attention_params<T>& params,
-                               const cudaStream_t& stream);
-
-template<typename T>
-struct SATypeConverter {
-    using Type = T;
-};
-
-template<>
-struct SATypeConverter<at::Half> {
-    using Type = uint16_t;
-};
-
-template<>
-struct SATypeConverter<at::BFloat16> {
-    using Type = __nv_bfloat16;
-};
-
-template<typename T>
-void set_params(Masked_multihead_attention_params<T> &params,
-                const size_t batch_size,
-                const size_t nheads,
-                const size_t nheads_kv,
-                const size_t memory_max_seqlen,
-                const size_t headdim,
-                const int timestep,
-                const int rotary_embedding_dim,
-                const float rotary_base,
-                const bool neox_rotary_style,
-                const int q_batch_stride,
-                const int k_batch_stride,
-                const int v_batch_stride,
-                const int nnz_heads,
-                T *q_ptr,
-                T *k_ptr,
-                T *v_ptr,
-                T *k_cache_ptr,
-                T *v_cache_ptr,
-                int *length_per_sample,
-                T *rotary_cos,
-                T *rotary_sin,
-                T *out_ptr,
-                int *nnz_head_idx) {
-    // Reset the parameters
-    memset(&params, 0, sizeof(params));
-    params.q = q_ptr;
-    params.k = k_ptr;
-    params.v = v_ptr;
-    params.q_bias = nullptr;
-    params.k_bias = nullptr;
-    params.v_bias = nullptr;
-    params.k_cache = k_cache_ptr;
-    params.v_cache = v_cache_ptr;
-    params.out = out_ptr;
-    params.cache_indir = nullptr;
-    params.stride_q = q_batch_stride;
-    params.stride_k = k_batch_stride;
-    params.stride_v = v_batch_stride;
-    params.batch_size = batch_size;
-    params.beam_width = 1;
-    params.memory_max_len = memory_max_seqlen;
-    params.num_heads = nheads;
-    params.num_heads_kv = nheads_kv;
-    params.num_heads_q_kv_ratio = nheads / nheads_kv;
-    params.nnz_heads = nnz_heads;
-    params.hidden_size_per_head = headdim;
-    params.rotary_embedding_dim = rotary_embedding_dim;
-    params.rotary_base = rotary_base;
-    params.neox_rotary_style = neox_rotary_style;
-    params.timestep = timestep;
-    params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
-    params.total_padding_tokens = nullptr;
-    params.masked_tokens = nullptr;
-    params.prefix_prompt_lengths = nullptr;
-    params.max_prefix_prompt_length = 0;
-    params.relative_attention_bias = nullptr;
-    params.relative_attention_bias_stride = 0;
-    params.cross_attention_out = nullptr;
-    params.max_decoder_seq_len = 0;
-    params.is_return_cross_attentions = false;
-    params.finished = nullptr;
-    params.memory_length_per_sample = nullptr;
-    params.length_per_sample = length_per_sample;
-    params.rotary_cos = rotary_cos;
-    params.rotary_sin = rotary_sin;
-    params.nnz_head_idx = nnz_head_idx;
-}
-
-torch::Tensor single_query_attention(const torch::Tensor q,
-                                     const torch::Tensor k,
-                                     const torch::Tensor v,
-                                     torch::Tensor k_cache,
-                                     torch::Tensor v_cache,
-                                     std::optional<const torch::Tensor> length_per_sample_,
-                                     std::optional<const torch::Tensor> rotary_cos_,
-                                     std::optional<const torch::Tensor> rotary_sin_,
-                                     std::optional<const torch::Tensor> nnz_head_idx_,
-                                     const int timestep,
-                                     int rotary_embedding_dim = 0,
-                                     const float rotary_base = 10000.0f,
-                                     const bool neox_rotary_style=true) {
-    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
-    int batch_size = v_cache.size(0);
-    int nheads = q.size(1);
-    int
nheads_kv = v_cache.size(1); - int memory_max_seqlen = v_cache.size(2); - int headdim = v_cache.size(3); - auto input_type = q.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - - CHECK_SHAPE(q, batch_size, nheads, headdim); - CHECK_SHAPE(k, batch_size, nheads_kv, headdim); - CHECK_SHAPE(v, batch_size, nheads_kv, headdim); - CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim); - // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32 - int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8; - CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize); - TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim); - TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim); - TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim); - CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache); - - TORCH_CHECK(q.scalar_type() == input_type); - TORCH_CHECK(k.scalar_type() == input_type); - TORCH_CHECK(v.scalar_type() == input_type); - TORCH_CHECK(k_cache.scalar_type() == input_type); - TORCH_CHECK(v_cache.scalar_type() == input_type); - - if (length_per_sample_.has_value()) { - auto length_per_sample = length_per_sample_.value(); - CHECK_DEVICE(length_per_sample); - CHECK_SHAPE(length_per_sample, batch_size); - CHECK_CONTIGUOUS(length_per_sample); - TORCH_CHECK(length_per_sample.dtype() == torch::kInt32); - } - - if (rotary_cos_.has_value()) { - auto rotary_cos = rotary_cos_.value(); - CHECK_DEVICE(rotary_cos); - rotary_embedding_dim = rotary_cos.size(-1) * 2; - CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2); - CHECK_CONTIGUOUS(rotary_cos); - TORCH_CHECK(rotary_cos.scalar_type() == input_type); - - TORCH_CHECK(rotary_sin_.has_value()); - auto rotary_sin = rotary_sin_.value(); - CHECK_DEVICE(rotary_sin); - CHECK_SHAPE(rotary_sin, batch_size, rotary_embedding_dim / 2); - CHECK_CONTIGUOUS(rotary_sin); - TORCH_CHECK(rotary_sin.scalar_type() == input_type); - } - - if (nnz_head_idx_.has_value()) { - auto nnz_head_idx = nnz_head_idx_.value(); - CHECK_DEVICE(nnz_head_idx); - int nnz_heads = nnz_head_idx.size(0); - CHECK_SHAPE(nnz_head_idx, nnz_heads); - CHECK_CONTIGUOUS(nnz_head_idx); - TORCH_CHECK(nnz_head_idx.dtype() == torch::kInt32); - } - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - - torch::Tensor out = torch::empty_like(q); - - DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] { - using DataType = typename SATypeConverter::Type; - Masked_multihead_attention_params params; - set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, timestep, - rotary_embedding_dim, rotary_base, neox_rotary_style, - q.stride(0), k.stride(0), v.stride(0), - nnz_head_idx_.has_value() ? nnz_head_idx_.value().size(0) : 0, - reinterpret_cast(q.data_ptr()), - reinterpret_cast(k.data_ptr()), - reinterpret_cast(v.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), - length_per_sample_.has_value() - ? length_per_sample_.value().data_ptr() : nullptr, - rotary_cos_.has_value() - ? reinterpret_cast(rotary_cos_.value().data_ptr()) : nullptr, - rotary_sin_.has_value() - ? reinterpret_cast(rotary_sin_.value().data_ptr()) : nullptr, - reinterpret_cast(out.data_ptr()), - nnz_head_idx_.has_value() ? 
nnz_head_idx_.value().data_ptr() : nullptr - ); - auto stream = at::cuda::getCurrentCUDAStream(); - masked_multihead_attention(params, stream); - }); - return out; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("single_query_attention", &single_query_attention, "Attention with a single query", - py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"), - py::arg("length_per_sample_"), py::arg("rotary_cos_"), - py::arg("rotary_sin_"), py::arg("nnz_head_idx_"), - py::arg("timestep"), py::arg("rotary_embedding_dim")=0, - py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true); -} diff --git a/csrc/ft_attention/setup.py b/csrc/ft_attention/setup.py deleted file mode 100644 index fa385ad768c..00000000000 --- a/csrc/ft_attention/setup.py +++ /dev/null @@ -1,153 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -from packaging.version import parse, Version - -from setuptools import setup, find_packages -import subprocess - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME - - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). 
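As context for the TORCH_CUDA_ARCH_LIST defaulting that follows, the snippet below is a minimal, hypothetical wrapper (not part of the original setup.py) showing how a user can pin the target compute capabilities up front so the build never needs to query a GPU; the "8.0;9.0" value is only an example.

import os
import subprocess

# Illustrative only: with TORCH_CUDA_ARCH_LIST set, the extension build skips
# torch.cuda.get_device_capability() and cross-compiles for the listed archs.
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "8.0;9.0")
subprocess.check_call(["python", "setup.py", "build_ext", "--inplace"])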
- print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -raise_if_cuda_home_none("--ft_attention") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("ft_attention is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - name="ft_attention", - sources=[ - "ft_attention.cpp", - "decoder_masked_multihead_attention.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-DENABLE_BF16"] + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-DENABLE_BF16", # TODO - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - + generator_flag - + cc_flag - ), - }, - include_dirs=[this_dir], - ) -) - -setup( - name="ft_attention", - version="0.1", - description="Attention for single query from FasterTransformer", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) diff --git a/csrc/fused_softmax/fused_softmax.cpp b/csrc/fused_softmax/fused_softmax.cpp deleted file mode 100644 index 2aaed913314..00000000000 --- a/csrc/fused_softmax/fused_softmax.cpp +++ /dev/null @@ -1,148 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -int get_batch_per_block_cuda( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads); - -torch::Tensor fwd( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) { - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); - - return fwd_cuda(input, mask, scale_factor); -} - -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - - AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return bwd_cuda(output_grads, softmax_results, scale_factor); -} - -int get_batch_per_block( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); -} - -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { - AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return fwd_cuda(input, scale_factor); -} - -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return bwd_cuda(output_grads, 
softmax_results, scale_factor); -} - -} // end namespace scaled_upper_triang_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("scaled_masked_softmax_forward", - &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); - - m.def("scaled_masked_softmax_backward", - &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); - - m.def("scaled_masked_softmax_get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, - "Return Batch per block size." - ); - - m.def("scaled_upper_triang_masked_softmax_forward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("scaled_upper_triang_masked_softmax_backward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); -} diff --git a/csrc/fused_softmax/scaled_masked_softmax.h b/csrc/fused_softmax/scaled_masked_softmax.h deleted file mode 100644 index 14b9f6e4242..00000000000 --- a/csrc/fused_softmax/scaled_masked_softmax.h +++ /dev/null @@ -1,528 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? 
b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Explicit masking - */ -template -__global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 
0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - // compute scale value to account for full mask - acc_t scale_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0; - } - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] * scale_value[i]/ sum[i]; - } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } - } - } -} - -template -__global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. 
compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } - } - } -} -} // end of anonymous namespace - -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - - return batches_per_block; -} - -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches) -{ - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 12: // 4096 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 13: // 8192 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) -{ - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 8192 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count/batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 5: // 32 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 11: // 2048 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 12: // 4096 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 13: // 8192 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: - break; - } - } -} diff --git a/csrc/fused_softmax/scaled_masked_softmax_cuda.cu b/csrc/fused_softmax/scaled_masked_softmax_cuda.cu deleted file mode 100644 index a08e752699c..00000000000 --- a/csrc/fused_softmax/scaled_masked_softmax_cuda.cu +++ /dev/null @@ -1,121 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include -#include -#include -#include "scaled_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); -} - - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) -{ - // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = input.size(0); - const int pad_batches = mask.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - TORCH_INTERNAL_ASSERT(key_seq_len <= 8192); - TORCH_INTERNAL_ASSERT(query_seq_len > 1); - TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); - TORCH_INTERNAL_ASSERT(mask.size(1) == 1); - TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); - TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* mask_ptr = static_cast(mask.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_masked_softmax_forward", - dispatch_scaled_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - reinterpret_cast(mask_ptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - pad_batches - ); - ); - return softmax_results; -} - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = output_grads.size(0); - const int attn_heads = output_grads.size(1); - const int query_seq_len = output_grads.size(2); - const int key_seq_len = output_grads.size(3); - - auto act_options = output_grads.options().requires_grad(false); - torch::Tensor input_grads = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - void* input_grads_ptr = static_cast(input_grads.data_ptr()); - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_masked_softmax_backward", - dispatch_scaled_masked_softmax_backward( - reinterpret_cast(input_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads - ); - ); - return input_grads; -} -} -} -} diff --git a/csrc/fused_softmax/scaled_upper_triang_masked_softmax.h b/csrc/fused_softmax/scaled_upper_triang_masked_softmax.h deleted file mode 100644 index 21e93fb313a..00000000000 --- a/csrc/fused_softmax/scaled_upper_triang_masked_softmax.h +++ /dev/null @@ -1,529 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -template -__device__ __inline__ void copy_zero_vector(Datatype *dst); - -template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } - -template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Implicit time (diagonal masking) - */ -template -__global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. 
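Functionally, the surrounding warp kernel computes a scaled, causally masked softmax over each row. A plain-PyTorch restatement of the same math (a sketch for checking semantics, not the fused implementation) is:

import torch

def scaled_upper_triang_softmax_reference(x: torch.Tensor, scale: float) -> torch.Tensor:
    # x: [attn_batches, seq_len, seq_len]; entries above the diagonal are masked out.
    seq_len = x.size(-1)
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).tril()
    scores = (x * scale).masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1)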
- constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it+element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < local_seq) { - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0; - } - } - copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); - } else if (element_index < element_count) { - copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); - } else { - break; - } - } - } -} - -template -__global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 
0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); - } - } - } -} - -} // end of anonymous namespace - -template -void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 12: // 4096 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 13: // 8192 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 8192 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 12: // 4096 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 13: // 8192 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} diff --git a/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu b/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu deleted file mode 100644 index 79ec30be364..00000000000 --- a/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu +++ /dev/null @@ -1,98 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. 
All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include "scaled_upper_triang_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor) -{ - // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = input.size(0); - const int seq_len = input.size(1); - TORCH_INTERNAL_ASSERT(seq_len <= 8192); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({attn_batches, seq_len, seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_forward", - dispatch_scaled_upper_triang_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - return softmax_results; -} - - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = output_grads.size(0); - const int seq_len = output_grads.size(1); - TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); - - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_backward", - dispatch_scaled_upper_triang_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - - //backward pass is completely in-place - return output_grads; -} -} -} -} diff --git a/csrc/fused_softmax/setup.py b/csrc/fused_softmax/setup.py deleted file mode 100644 index 9c1c6ed76e9..00000000000 --- a/csrc/fused_softmax/setup.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copied from https://github.com/NVIDIA/apex/tree/master/csrc/megatron -# We add the case where seqlen = 4k and seqlen = 8k -import os -import subprocess - -import torch -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - 
bare_metal_minor = release[1][0] - - return raw_output, bare_metal_major, bare_metal_minor - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -cc_flag = [] -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") - -setup( - name='fused_softmax_lib', - ext_modules=[ - CUDAExtension( - name='fused_softmax_lib', - sources=['fused_softmax.cpp', 'scaled_masked_softmax_cuda.cu', 'scaled_upper_triang_masked_softmax_cuda.cu'], - extra_compile_args={ - 'cxx': ['-O3',], - 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) - } - ) - ], - cmdclass={ - 'build_ext': BuildExtension -}) diff --git a/csrc/fused_softmax/type_shim.h b/csrc/fused_softmax/type_shim.h deleted file mode 100644 index 815ec7ec889..00000000000 --- a/csrc/fused_softmax/type_shim.h +++ /dev/null @@ -1,20 +0,0 @@ -#include - -#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ -switch(TYPE) \ -{ \ -case at::ScalarType::Half: \ - { \ -using scalar_t = at::Half; \ -__VA_ARGS__; \ -break; \ - } \ -case at::ScalarType::BFloat16: \ - { \ -using scalar_t = at::BFloat16; \ -__VA_ARGS__; \ -break; \ - } \ -default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ -} diff --git a/csrc/rotary/rotary.cpp b/csrc/rotary/rotary.cpp deleted file mode 100644 index 640eea423ac..00000000000 --- a/csrc/rotary/rotary.cpp +++ /dev/null @@ -1,40 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include -#include - -#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, - const torch::Tensor cos, const torch::Tensor sin, - torch::Tensor out1, torch::Tensor out2, - const bool conj); - -void apply_rotary(const torch::Tensor x1, const torch::Tensor x2, - const torch::Tensor cos, const torch::Tensor sin, - torch::Tensor out1, torch::Tensor out2, - const bool conj) { - CHECK_DEVICE(x1); CHECK_DEVICE(x2); - CHECK_DEVICE(cos); CHECK_DEVICE(sin); - CHECK_DEVICE(out1); CHECK_DEVICE(out1); - TORCH_CHECK(x1.dtype() == x2.dtype()); - TORCH_CHECK(cos.dtype() == sin.dtype()); - TORCH_CHECK(out1.dtype() == out2.dtype()); - TORCH_CHECK(x1.dtype() == cos.dtype()); - TORCH_CHECK(x1.dtype() == out1.dtype()); - TORCH_CHECK(x1.sizes() == x2.sizes()); - TORCH_CHECK(cos.sizes() == sin.sizes()); - TORCH_CHECK(out1.sizes() == out2.sizes()); - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{x1.device()}; - - apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("apply_rotary", &apply_rotary, "Apply rotary embedding"); -} diff --git a/csrc/rotary/rotary_cuda.cu b/csrc/rotary/rotary_cuda.cu deleted file mode 100644 index 2dd0ff3f6e2..00000000000 --- a/csrc/rotary/rotary_cuda.cu +++ /dev/null @@ -1,45 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include -#include -#include - -void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, - const torch::Tensor cos, const torch::Tensor sin, - torch::Tensor out1, torch::Tensor out2, - const bool conj) { - auto iter = at::TensorIteratorConfig() - .add_output(out1) - .add_output(out2) - .add_input(x1) - .add_input(x2) - .add_input(cos) - .add_input(sin) - .check_all_same_dtype(false) - .promote_inputs_to_common_dtype(false) - .build(); - - if (!conj) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { - at::native::gpu_kernel_multiple_outputs( - iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, - scalar_t sin) -> thrust::tuple { - scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin); - scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos); - return {out1, out2}; - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { - at::native::gpu_kernel_multiple_outputs( - iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, - scalar_t sin) -> thrust::tuple { - scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin); - scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos); - return {out1, out2}; - }); - }); - } -} \ No newline at end of file diff --git a/csrc/rotary/setup.py b/csrc/rotary/setup.py deleted file mode 100644 index 24d328d9c6a..00000000000 --- a/csrc/rotary/setup.py +++ /dev/null @@ -1,126 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -from packaging.version import parse, Version - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME -from setuptools import setup, find_packages -import subprocess - -# ninja build does not work unless include_dirs are abs path -this_dir = 
os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). 
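
For orientation, the rotation that the deleted `rotary_cuda.cu` kernel applies elementwise is `out1 = x1*cos - x2*sin`, `out2 = x1*sin + x2*cos`, with the signs on `sin` flipped when `conj` is set (the backward pass). The following is only a minimal PyTorch sketch of that same math, kept as a readable reference for what the fused kernel computed; it is not part of the diff.

```python
import torch

def apply_rotary_ref(x1, x2, cos, sin, conj=False):
    # Reference for the deleted rotary_cuda.cu kernel: compute in fp32, cast back.
    x1f, x2f, cosf, sinf = (t.float() for t in (x1, x2, cos, sin))
    if not conj:
        out1 = x1f * cosf - x2f * sinf
        out2 = x1f * sinf + x2f * cosf
    else:
        # conjugate rotation, used for the backward pass
        out1 = x1f * cosf + x2f * sinf
        out2 = -x1f * sinf + x2f * cosf
    return out1.to(x1.dtype), out2.to(x1.dtype)
```
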
- print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -raise_if_cuda_home_none("rotary_emb") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("rotary_emb is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - 'rotary_emb', [ - 'rotary.cpp', - 'rotary_cuda.cu', - ], - extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'], - 'nvcc': append_nvcc_threads([ - '-O3', '--use_fast_math', '--expt-extended-lambda' - ] + cc_flag) - } - ) -) - -setup( - name="rotary_emb", - version="0.1", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) diff --git a/csrc/xentropy/README.md b/csrc/xentropy/README.md deleted file mode 100644 index 1bc90fdab77..00000000000 --- a/csrc/xentropy/README.md +++ /dev/null @@ -1,14 +0,0 @@ -This CUDA extension implements optimized cross-entropy loss, adapted from Apex's -[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). -We make it work for bfloat16 and support in-place backward to save memory. - -It has only been tested on A100s. - -```sh -cd csrc/xentropy && pip install . -``` - -As of 2023-09-15, this extension is no longer used in the FlashAttention repo. -We've instead switched to a Triton-based -[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py). -See the CrossEntropyLoss [module](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py) for more details. 
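
The xentropy interface and kernel removed below compute a label-smoothed cross-entropy per example, with `total_classes` allowing the smoothing term to be correct under tensor parallelism where `logits` only holds a vocabulary shard. As a sketch of the math only (single-GPU case, full vocabulary in `logits`; this is not the extension itself), the per-example loss reduces to:

```python
import torch

def smoothed_xentropy_ref(logits, labels, smoothing=0.0, total_classes=-1):
    """Reference for the forward loss in xentropy_kernel.cu (single-GPU case)."""
    if total_classes <= 0:
        total_classes = logits.size(-1)
    logits = logits.float()
    lse = torch.logsumexp(logits, dim=-1)                       # max + log(sum(exp))
    nll = lse - logits.gather(-1, labels[:, None]).squeeze(-1)  # -log p(label)
    smooth = lse - logits.sum(dim=-1) / total_classes           # uniform-target term
    return (1.0 - smoothing) * nll + smoothing * smooth
```

With `smoothing=0` this is the usual per-token cross-entropy; the CUDA version additionally supports an in-place backward through `logits` to avoid materializing a separate gradient tensor.
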
diff --git a/csrc/xentropy/interface.cpp b/csrc/xentropy/interface.cpp deleted file mode 100644 index 41a783fd0fc..00000000000 --- a/csrc/xentropy/interface.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include - -// CUDA forward declarations -std::vector softmax_xentropy_cuda( - const at::Tensor &input, - const at::Tensor &labels, - const float smoothing, - const int total_classes); - -at::Tensor softmax_xentropy_backward_cuda( - const at::Tensor &grad_loss, - at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - const bool inplace, - const int total_classes); - -// C++ interface - -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector softmax_xentropy_forward( - const at::Tensor &input, - const at::Tensor &labels, - const float smoothing, - const int total_classes=-1) { - // For tensor parallel cross entropy with smoothing, we want to pass in the total number - // of classes so that smoothing can be applied correctly. If total_classes=-1, use the - // last dimension of the input tensor. - CHECK_INPUT(input); - CHECK_INPUT(labels); - - return softmax_xentropy_cuda(input, labels, smoothing, total_classes); -} - -at::Tensor softmax_xentropy_backward( - const at::Tensor &grad_loss, - at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - const bool inplace, - const int total_classes=-1) { - CHECK_INPUT(grad_loss); - CHECK_INPUT(logits); - CHECK_INPUT(max_log_sum_exp); - CHECK_INPUT(labels); - - return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, - smoothing, inplace, total_classes); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", py::arg("input"), py::arg("labels"), py::arg("smoothing"), py::arg("total_classes")=-1); - m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", py::arg("grad_loss"), py::arg("logits"), py::arg("max_log_sum_exp"), py::arg("labels"), py::arg("smoothing"), py::arg("inplace"), py::arg("total_classes")=-1); -} diff --git a/csrc/xentropy/setup.py b/csrc/xentropy/setup.py deleted file mode 100644 index 5079b4f3847..00000000000 --- a/csrc/xentropy/setup.py +++ /dev/null @@ -1,139 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -from packaging.version import parse, Version - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME -from setuptools import setup, find_packages -import subprocess - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from 
" + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -raise_if_cuda_home_none("--xentropy") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("xentropy is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") 
-cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - name="xentropy_cuda_lib", - sources=[ - "interface.cpp", - "xentropy_kernel.cu" - ], - extra_compile_args={ - "cxx": ["-O3"] + generator_flag, - "nvcc": append_nvcc_threads( - ["-O3"] - + generator_flag - + cc_flag - ), - }, - include_dirs=[this_dir], - ) -) - -setup( - name="xentropy_cuda_lib", - version="0.1", - description="Cross-entropy loss", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) diff --git a/csrc/xentropy/xentropy_kernel.cu b/csrc/xentropy/xentropy_kernel.cu deleted file mode 100644 index 66aab0007ba..00000000000 --- a/csrc/xentropy/xentropy_kernel.cu +++ /dev/null @@ -1,758 +0,0 @@ -// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu -// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory). -/** - * From PyTorch: - * - * Copyright (c) 2016- Facebook, Inc (Adam Paszke) - * Copyright (c) 2014- Facebook, Inc (Soumith Chintala) - * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) - * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) - * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) - * Copyright (c) 2011-2013 NYU (Clement Farabet) - * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) - * Copyright (c) 2006 Idiap Research Institute (Samy Bengio) - * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) - * - * From Caffe2: - * - * Copyright (c) 2016-present, Facebook Inc. All rights reserved. - * - * All contributions by Facebook: - * Copyright (c) 2016 Facebook Inc. - * - * All contributions by Google: - * Copyright (c) 2015 Google Inc. - * All rights reserved. - * - * All contributions by Yangqing Jia: - * Copyright (c) 2015 Yangqing Jia - * All rights reserved. - * - * All contributions from Caffe: - * Copyright(c) 2013, 2014, 2015, the respective contributors - * All rights reserved. - * - * All other contributions: - * Copyright(c) 2015, 2016 the respective contributors - * All rights reserved. - * - * Caffe2 uses a copyright model similar to Caffe: each contributor holds - * copyright over their contributions to Caffe2. The project versioning records - * all such contribution and copyright details. If a contributor wants to further - * mark their specific copyright on a particular contribution, they should - * indicate their copyright solely in the commit message of the change when it is - * committed. - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. 
Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America - * and IDIAP Research Institute nor the names of its contributors may be - * used to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ -#include -#include -#include - -#include -#include - -// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h -// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_##LEVEL = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } -// #else -// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \ -// switch(TYPE) \ -// { \ -// case at::ScalarType::Float: \ -// { \ -// using scalar_t_##LEVEL = float; \ -// __VA_ARGS__; \ -// break; \ -// } \ -// case at::ScalarType::Half: \ -// { \ -// using scalar_t_##LEVEL = at::Half; \ -// __VA_ARGS__; \ -// break; \ -// } \ -// default: \ -// AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ -// } -// #endif - -#define ALIGN_BYTES 16 - -using Tensor = at::Tensor; -using TensorList = at::TensorList; -using ScalarType = at::ScalarType; -using at::acc_type; - -template -struct LogSoftMaxForwardEpilogue { - __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum) - : logsum(max_input + std::log(sum)) {} - - __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp) - : logsum(max_log_sum_exp) {} - - __device__ __forceinline__ OutT operator()(T input) const { - return static_cast(input - logsum); - } - - const AccumT logsum; -}; - -template -struct LogSoftMaxBackwardEpilogue { - __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) - : sum(sum) {} - - __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const { - return static_cast(gradOutput - std::exp(static_cast(output)) * sum); - } - - const AccumT sum; -}; - - - -const int max_threads = 1024; - -inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { - uint64_t block_size = 1; - uint64_t max_block_size = std::min(dim_size / ILP, static_cast(max_threads)); - while (block_size < (max_block_size/2)) block_size *= 2; - // Launch at least a single warp - the kernel assumes that. 
- block_size = std::max(block_size, static_cast(32)); - return dim3(block_size); -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - - -//////////////////////////////////////////////////////////////////////////////// -// Regular kernel (fast when dim_size is large; requires inner_size == 1) -//////////////////////////////////////////////////////////////////////////////// - - -template -struct MaxFloat -{ - __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { - return ::max(max, (AccumT)v); - } -}; - -template -struct AddFloat -{ - __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { - return sum + v; - } -}; - -template -struct SumExpFloat -{ - __device__ __forceinline__ SumExpFloat(AccumT v) - : max_k(v) {} - - __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { - return sum + std::exp(v - max_k); - } - - const AccumT max_k; -}; - -template class Reduction, typename AccumT> -__device__ __forceinline__ AccumT -blockReduce(AccumT* smem, AccumT val, - const Reduction& r, - AccumT defaultVal) -{ - // To avoid RaW races from chaining blockReduce calls together, we need a sync here - __syncthreads(); - - smem[threadIdx.x] = val; - - __syncthreads(); - - AccumT warpVal = defaultVal; - - // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; - if (threadIdx.x < 32) { - int lane = threadIdx.x % 32; - if (lane < blockDim.x / 32) { -#pragma unroll - for (int i = 0; i < 32; ++i) { - warpVal = r(warpVal, smem[lane * 32 + i]); - } - __syncwarp(mask); - smem[lane] = warpVal; - } - } - - __syncthreads(); - - // First thread will perform a reduction of the above per-warp reductions - AccumT blockVal = defaultVal; - - if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / 32; ++i) { - blockVal = r(blockVal, smem[i]); - } - smem[0] = blockVal; - } - - // Sync and broadcast - __syncthreads(); - return smem[0]; -} - -template class Reduction1, template class Reduction2, typename AccumT> -__device__ __forceinline__ void -blockReduce(AccumT* smem, - AccumT* reducVal1, - AccumT val1, - const Reduction1& r1, - AccumT defaultVal1, - AccumT* reducVal2, - AccumT val2, - const Reduction2& r2, - AccumT defaultVal2) -{ - // To avoid RaW races from chaining blockReduce calls together, we need a sync here - __syncthreads(); - - smem[threadIdx.x] = val1; - smem[blockDim.x + threadIdx.x] = val2; - - __syncthreads(); - - AccumT warpVal1 = defaultVal1; - AccumT warpVal2 = defaultVal2; - - // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; - if (threadIdx.x < 32) { - int lane = threadIdx.x % 32; - if (lane < blockDim.x / 32) { -#pragma unroll - for (int i = 0; i < 32; ++i) { - warpVal1 = r1(warpVal1, smem[lane * 32 + i]); - warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]); - } - __syncwarp(mask); - smem[lane] = warpVal1; - smem[lane + blockDim.x] = warpVal2; - } - } - - __syncthreads(); - - // First thread will perform a reduction of the above per-warp reductions - AccumT blockVal1 = defaultVal1; - AccumT blockVal2 = defaultVal2; - - if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / 32; ++i) { - blockVal1 = r1(blockVal1, smem[i]); - blockVal2 = r2(blockVal2, smem[i + blockDim.x]); - } - smem[0] = blockVal1; - 
smem[blockDim.x] = blockVal2; - } - - // Sync and broadcast - __syncthreads(); - *reducVal1 = smem[0]; - *reducVal2 = smem[blockDim.x]; - __syncthreads(); -} - -template class Reduction, int ILP, typename T, typename AccumT> -__device__ __forceinline__ AccumT -ilpReduce(int shift, - T* data, - int size, - const Reduction& r, - AccumT defaultVal) -{ - typedef typename std::aligned_storage::type LoadT; - AccumT threadVal = defaultVal; - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - data -= shift; - size += shift; - if(threadIdx.x >= shift){ - threadVal = r(threadVal, data[offset]); - } - size -= blockDim.x; - data += blockDim.x; - } - int last = size % (ILP * blockDim.x); - - T v[ILP]; - LoadT* value = reinterpret_cast(&v); - - for (; offset * ILP < (size - last); offset += blockDim.x) { - *value = reinterpret_cast(data)[offset]; - - for (int j = 0; j < ILP; ++j) { - threadVal = r(threadVal, v[j]); - } - } - - offset = size - last + threadIdx.x; - // Epilogue - for (; offset < size; offset += blockDim.x) - threadVal = r(threadVal, data[offset]); - - return threadVal; -} - -template class Reduction1, template class Reduction2, int ILP, typename T, typename AccumT> -__device__ __forceinline__ void -ilpReduce(int shift, - T* data, - int size, - AccumT* reducVal1, - const Reduction1& r1, - AccumT defaultVal1, - AccumT* reducVal2, - const Reduction2& r2, - AccumT defaultVal2) -{ - typedef typename std::aligned_storage::type LoadT; - - AccumT threadVal1 = defaultVal1; - AccumT threadVal2 = defaultVal2; - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - data -= shift; - size += shift; - if(threadIdx.x >= shift){ - threadVal1 = r1(threadVal1, data[offset]); - threadVal2 = r2(threadVal2, data[offset]); - } - size -= blockDim.x; - data += blockDim.x; - } - int last = size % (ILP * blockDim.x); - - T v[ILP]; - LoadT* value = reinterpret_cast(&v); - - for (; offset * ILP < (size - last); offset += blockDim.x) { - *value = reinterpret_cast(data)[offset]; - - for (int j = 0; j < ILP; ++j) { - threadVal1 = r1(threadVal1, v[j]); - threadVal2 = r2(threadVal2, v[j]); - } - } - - offset = size - last + threadIdx.x; - // Epilogue - for (; offset < size; offset += blockDim.x) { - threadVal1 = r1(threadVal1, data[offset]); - threadVal2 = r2(threadVal2, data[offset]); - } - - *reducVal1 = threadVal1; - *reducVal2 = threadVal2; -} - -template class Epilogue> -__global__ void -cunn_SoftMaxXEntropyForward( - accscalar_t *losses, - outscalar_t *max_log_sum_exp, - scalar_t *input, - int64_t *labels, - int64_t classes, - const float smoothing, - const int total_classes) -{ - extern __shared__ unsigned char smem[]; - auto sdata = reinterpret_cast(smem); - // forward pointers to batch[blockIdx.x] - // each block handles a sample in the mini-batch - input += blockIdx.x * classes; - //output += blockIdx.x * classes; - const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t); - - int64_t label = labels[blockIdx.x]; - - // find the max and sum - accscalar_t threadMax, threadSum, max_k, sum_k; - ilpReduce( - shift, input, classes, - &threadMax, MaxFloat(), - -at::numeric_limits::max(), - &threadSum, AddFloat(), - static_cast(0)); - - blockReduce( - sdata, - &max_k, threadMax, Max(), - -at::numeric_limits::max(), - &sum_k, threadSum, Add(), - static_cast(0)); - - accscalar_t threadExp = ilpReduce(shift, input, classes, SumExpFloat(max_k), static_cast(0)); - accscalar_t sumAll = blockReduce( - sdata, threadExp, Add(), static_cast(0)); - - Epilogue epilogue(max_k, sumAll); - - // 
calculate per element loss with label smoothing - // reserve max + log_sum_exp for bprop - if (threadIdx.x == 0) { - accscalar_t lse = max_k + std::log(sumAll); - accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast(input[label])) : 0.f; - losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing); - max_log_sum_exp[blockIdx.x] = lse; - } -} - -template -__device__ __forceinline__ void -apply(scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes, - const int total_classes) -{ - accscalar_t smooth_positives = 1.0 - smoothing; - accscalar_t smooth_negatives = smoothing / total_classes; - accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; - int64_t label = labels[blockIdx.x]; - accscalar_t coeff = max_log_sum_exp[blockIdx.x]; - - int offset = threadIdx.x; - int last = classes % (ILP * blockDim.x); - - for (; offset < classes - last; offset += blockDim.x * ILP) { - accscalar_t tmpLogits[ILP]; - -#pragma unroll - for (int j = 0; j < ILP; ++j) { - tmpLogits[j] = static_cast(logits[offset + j * blockDim.x]); - } - -#pragma unroll - for (int j = 0; j < ILP; ++j) - gradInput[offset + j * blockDim.x] = tmpGradOutput * ( - std::exp(tmpLogits[j] - coeff) - static_cast( - (offset + j * blockDim.x == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - - for (; offset < classes; offset += blockDim.x) - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast((offset == label) ? 1 : 0) * - smooth_positives - smooth_negatives); -} - - -template -__device__ __forceinline__ void -aligned_apply(int shift, - scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes, - const int total_classes) -{ - accscalar_t smooth_positives = 1.0 - smoothing; - accscalar_t smooth_negatives = smoothing / total_classes; - accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; - int64_t label = labels[blockIdx.x]; - accscalar_t coeff = max_log_sum_exp[blockIdx.x]; - - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - logits -= shift; - gradInput -= shift; - classes += shift; - if(threadIdx.x >= shift){ - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast(((offset - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - classes -= blockDim.x; - gradInput += blockDim.x; - logits += blockDim.x; - shift -= blockDim.x; - } - - int last = classes % (ILP * blockDim.x); - - typedef typename std::aligned_storage::type LoadT; - // input - scalar_t v[ILP]; - LoadT* value = reinterpret_cast(&v); - // output - scalar_t r[ILP]; - LoadT* result = reinterpret_cast(&r); - - for (; offset * ILP < (classes - last); offset += blockDim.x) { - *value = reinterpret_cast(logits)[offset]; - -#pragma unroll - for (int j = 0; j < ILP; ++j) { - r[j] = tmpGradOutput * (std::exp( - static_cast(v[j]) - coeff) - - static_cast(((ILP * offset + j - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - reinterpret_cast(gradInput)[offset] = *result; - } - - offset = classes - last + threadIdx.x; - for (; offset < classes; offset += blockDim.x) - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast(((offset - shift) == label) ? 
1 : 0) * - smooth_positives - smooth_negatives); - -} - -template class Epilogue> -__global__ void -cunn_SoftMaxXEntropyBackward( - scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes, - const int total_classes) -{ - gradInput += blockIdx.x * classes; - logits += blockIdx.x * classes; - - // Do vectorized load/store when input/output have same alignment - const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t); - const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t); - if (shift == shift_){ - aligned_apply(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes); - } - else { - apply(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes); - } - -} - -template class Epilogue> -std::vector host_softmax_xentropy( - const Tensor & input_, - const Tensor & labels_, - const float smoothing, - const int total_classes) { - // For tensor parallel cross entropy with smoothing, we want to pass in the total number - // of classes so that smoothing can be applied correctly. If total_classes=-1, use the - // last dimension of the input tensor. - AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long"); - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{input_.device()}; - - auto input = input_.contiguous(); - Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float)); - Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float)); - - static_assert(std::is_same, float>::value || - std::is_same, double>::value, - "accscalar_t for half should be float or double"); - AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported"); - AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional"); - AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples"); - AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0"); - - const int64_t dim = 1; - int64_t outer_size = 1; - int64_t dim_size = input.size(dim); - int64_t inner_size = 1; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - for (int64_t i = 0; i < dim; ++i) - outer_size *= input.size(i); - for (int64_t i = dim + 1; i < input.dim(); ++i) - inner_size *= input.size(i); - // This kernel spawns a block per each element in the batch. - // XXX: it assumes that inner_size == 1 - TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); - - dim3 grid(outer_size); - - using namespace at; - DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy", - using accscalar_t = at::acc_type; - const int ILP = sizeof(float4)/sizeof(scalar_t_0); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - cunn_SoftMaxXEntropyForward - <<>>( - losses.data_ptr(), max_log_sum_exp.data_ptr(), - input.data_ptr(), labels_.data_ptr(), - dim_size, smoothing, total_classes <= 0 ? 
dim_size : total_classes - ); - ); - - C10_CUDA_CHECK(cudaGetLastError()); - - std::vector ret = {losses, max_log_sum_exp}; - return ret; -} - -template class Epilogue> -Tensor host_softmax_xentropy_backward( - const at::Tensor &grad_loss, - at::Tensor &logits_, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - bool inplace, - const int total_classes) { - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{grad_loss.device()}; - - const int64_t dim = 1; - Tensor gI = inplace ? logits_ : at::empty_like(logits_); - if (grad_loss.numel() == 0) { - return gI; - } - - auto grad = grad_loss.contiguous(); - auto logits = logits_.contiguous(); - - static_assert(std::is_same, float>::value || - std::is_same, double>::value, - "accscalar_t for half should be float or double"); - if (grad.dim() == 0) grad = grad.view(1); - - AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported"); - AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional"); - AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0"); - AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples"); - AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples"); - - int64_t outer_size = 1; - int64_t dim_size = logits.size(dim); - int64_t inner_size = 1; - for (int64_t i = 0; i < dim; ++i) - outer_size *= logits.size(i); - for (int64_t i = dim + 1; i < logits.dim(); ++i) - inner_size *= logits.size(i); - // See descriptions of kernels above. - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); - - dim3 grid(outer_size); - - DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward", - using accscalar_t = acc_type; - const int ILP = sizeof(float4)/sizeof(scalar_t_0); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - cunn_SoftMaxXEntropyBackward - <<>>( - gI.data_ptr(), logits.data_ptr(), - max_log_sum_exp.data_ptr(), - grad.data_ptr(), labels.data_ptr(), - smoothing, dim_size, total_classes - ); - ); - - C10_CUDA_CHECK(cudaGetLastError()); - return gI; -} - -std::vector softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){ - return host_softmax_xentropy(input, labels, smoothing, total_classes); -} - -at::Tensor softmax_xentropy_backward_cuda( - const at::Tensor &grad_loss, - at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - const bool inplace, - const int total_classes) { - AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float"); - return host_softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes); -} diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index db131242dd4..4a8a7c33f46 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.7.4.post1" +__version__ = "2.8.3" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/flash_attn/cute/.flake8 b/flash_attn/cute/.flake8 new file mode 100644 index 00000000000..bae5b85c002 --- /dev/null +++ b/flash_attn/cute/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 100 +# W503: line break before binary operator +ignore = E731, E741, F841, W503 diff --git a/flash_attn/cute/AUTHORS 
b/flash_attn/cute/AUTHORS new file mode 100644 index 00000000000..bc3991c676d --- /dev/null +++ b/flash_attn/cute/AUTHORS @@ -0,0 +1,5 @@ +Tri Dao, tri@tridao.me +Jay Shah +Ted Zadouri +Markus Hoehnerbach +Vijay Thakkar \ No newline at end of file diff --git a/flash_attn/cute/LICENSE b/flash_attn/cute/LICENSE new file mode 100644 index 00000000000..5860e4b33f3 --- /dev/null +++ b/flash_attn/cute/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/flash_attn/cute/README.md b/flash_attn/cute/README.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/flash_attn/cute/__init__.py b/flash_attn/cute/__init__.py new file mode 100644 index 00000000000..fbbfc14050e --- /dev/null +++ b/flash_attn/cute/__init__.py @@ -0,0 +1,21 @@ +"""Flash Attention CUTE (CUDA Template Engine) implementation.""" + +__version__ = "0.1.0" + +import cutlass.cute as cute + +from .interface import ( + flash_attn_func, + flash_attn_varlen_func, +) + +from flash_attn.cute.cute_dsl_utils import cute_compile_patched + +# Patch cute.compile to optionally dump SASS +cute.compile = cute_compile_patched + + +__all__ = [ + "flash_attn_func", + "flash_attn_varlen_func", +] diff --git a/flash_attn/cute/ampere_helpers.py b/flash_attn/cute/ampere_helpers.py new file mode 100644 index 00000000000..e3072d8ce85 --- /dev/null +++ b/flash_attn/cute/ampere_helpers.py @@ -0,0 +1,103 @@ +# Copyright (c) 2025, Tri Dao. 
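
The new package's `__init__.py` re-exports `flash_attn_func` and `flash_attn_varlen_func` from `flash_attn.cute.interface`. Assuming the CUTE entry point keeps the familiar `flash_attn_func` convention (tensors of shape `(batch, seqlen, nheads, headdim)` in fp16/bf16 on a CUDA device, with a `causal` flag), a minimal call would look like the hypothetical sketch below; the authoritative signature is whatever `flash_attn/cute/interface.py` defines, and running it requires the CUTLASS Python DSL that this package imports.

```python
# Hypothetical usage sketch; layout and `causal` kwarg are assumptions,
# not taken from interface.py.
import torch
from flash_attn.cute import flash_attn_func

q, k, v = (
    torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # expected: (2, 1024, 16, 128)
```
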
+from typing import Type, Callable, Optional + +import cutlass +import cutlass.cute as cute + + +def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout: + dtype_byte = cutlass.const_expr(dtype.width // 8) + bytes_per_row = cutlass.const_expr(k_dim * dtype_byte) + smem_k_block_size = ( + cutlass.const_expr( + 128 + if bytes_per_row % 128 == 0 + else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16)) + ) + // dtype_byte + ) + swizzle_bits = ( + 4 + if smem_k_block_size == 128 + else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1)) + ) + swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4) + return cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base), + 0, + cute.make_ordered_layout( + (8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0) + ), + ) + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + tCsA: cute.Tensor, + tCsB: cute.Tensor, + smem_thr_copy_A: cute.TiledCopy, + smem_thr_copy_B: cute.TiledCopy, + hook_fn: Optional[Callable] = None, + A_in_regs: cutlass.Constexpr[bool] = False, + B_in_regs: cutlass.Constexpr[bool] = False, + swap_AB: cutlass.Constexpr[bool] = False, +) -> None: + if cutlass.const_expr(swap_AB): + gemm( + tiled_mma, + acc, + tCrB, + tCrA, + tCsB, + tCsA, + smem_thr_copy_B, + smem_thr_copy_A, + hook_fn, + A_in_regs=B_in_regs, + B_in_regs=A_in_regs, + swap_AB=False, + ) + else: + tCrA_copy_view = smem_thr_copy_A.retile(tCrA) + tCrB_copy_view = smem_thr_copy_B.retile(tCrB) + if cutlass.const_expr(not A_in_regs): + cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0]) + if cutlass.const_expr(not B_in_regs): + cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0]) + for k in cutlass.range_constexpr(cute.size(tCsA.shape[2])): + if k < cute.size(tCsA.shape[2]) - 1: + if cutlass.const_expr(not A_in_regs): + cute.copy( + smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1] + ) + if cutlass.const_expr(not B_in_regs): + cute.copy( + smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1] + ) + cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + if cutlass.const_expr(k == 0 and hook_fn is not None): + hook_fn() + + +@cute.jit +def gemm_rs( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + tCsB: cute.Tensor, + smem_thr_copy_B: cute.TiledCopy, + hook_fn: Optional[Callable] = None, +) -> None: + tCrB_copy_view = smem_thr_copy_B.retile(tCrB) + cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0]) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if cutlass.const_expr(k < cute.size(tCrA.shape[2]) - 1): + cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]) + cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + if cutlass.const_expr(k == 0 and hook_fn is not None): + hook_fn() diff --git a/flash_attn/cute/barrier.py b/flash_attn/cute/barrier.py new file mode 100644 index 00000000000..c999b180167 --- /dev/null +++ b/flash_attn/cute/barrier.py @@ -0,0 +1,71 @@ +import cutlass +import cutlass.cute as cute +from cutlass import Int32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm + + +@dsl_user_op +def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> 
cutlass.Int32: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + state = llvm.inline_asm( + T.i32(), + [lock_ptr_i64], + "ld.global.acquire.gpu.b32 $0, [$1];", + "=r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + return cutlass.Int32(state) + + +@dsl_user_op +def red_relaxed( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.relaxed.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def red_release( + lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None +) -> None: + lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)], + "red.release.gpu.global.add.s32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + read_val = Int32(0) + while read_val != val: + read_val = ld_acquire(flag_ptr) + + +@cute.jit +def arrive_inc( + lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32] +) -> None: + flag_ptr = lock_ptr + flag_offset + if thread_idx == 0: + red_release(flag_ptr, val) + # red_relaxed(flag_ptr, val) diff --git a/flash_attn/cute/benchmark.py b/flash_attn/cute/benchmark.py new file mode 100644 index 00000000000..9a7820e7b0c --- /dev/null +++ b/flash_attn/cute/benchmark.py @@ -0,0 +1,268 @@ +# Copyright (c) 2023, Tri Dao. 
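
`barrier.py` above is a small inter-CTA synchronization primitive: `arrive_inc` performs a release-ordered atomic add on a flag in global memory (PTX `red.release.gpu.global.add.s32`), and `wait_eq` has thread 0 spin on an acquire load (`ld.global.acquire.gpu.b32`) until the flag reaches the expected value. Purely as an illustration of that produce-then-signal pattern, here is a CPU analogue using a condition variable; the GPU version of course relies on the PTX instructions, not on Python threads.

```python
import threading

class Flag:
    """CPU stand-in for the global-memory flag used by wait_eq/arrive_inc."""
    def __init__(self):
        self._val = 0
        self._cond = threading.Condition()

    def arrive_inc(self, val=1):
        # analogue of the release-ordered atomic add: publish writes, bump counter
        with self._cond:
            self._val += val
            self._cond.notify_all()

    def wait_eq(self, val):
        # analogue of spinning on an acquire load until the counter equals val
        with self._cond:
            self._cond.wait_for(lambda: self._val == val)

flag, box = Flag(), {}

def producer():
    box["partial"] = sum(range(100))  # stand-in for a tile written to global memory
    flag.arrive_inc()

def consumer():
    flag.wait_eq(1)                   # only read after the producer has signalled
    print(box["partial"])

threads = [threading.Thread(target=f) for f in (consumer, producer)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```
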
+"""Useful functions for writing test code.""" + +import torch +import torch.utils.benchmark as benchmark + + +def benchmark_forward( + fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs +): + """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" + if verbose: + print(desc, "- Forward pass") + + def amp_wrapper(*inputs, **kwinputs): + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + fn(*inputs, **kwinputs) + + t = benchmark.Timer( + stmt="fn_amp(*inputs, **kwinputs)", + globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_backward( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the backward pass of an arbitrary function.""" + if verbose: + print(desc, "- Backward pass") + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError("Grad shape does not match output shape") + + def f(*inputs, y, grad): + # Set .grad to None to avoid extra operation of gradient accumulation + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + y.backward(grad, retain_graph=True) + + t = benchmark.Timer( + stmt="f(*inputs, y=y, grad=grad)", + globals={"f": f, "inputs": inputs, "y": y, "grad": grad}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_combined( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + if verbose: + print(desc, "- Forward + Backward pass") + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError("Grad shape does not match output shape") + + def f(grad, *inputs, **kwinputs): + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + y.backward(grad, retain_graph=True) + + t = benchmark.Timer( + stmt="f(grad, *inputs, **kwinputs)", + globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_fwd_bwd( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + return ( + benchmark_forward( + fn, + *inputs, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_backward( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + ) + + +def benchmark_all( + fn, + *inputs, + grad=None, + repeats=10, + desc="", + verbose=True, + amp=False, + amp_dtype=torch.float16, 
+ **kwinputs, +): + """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" + return ( + benchmark_forward( + fn, + *inputs, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_backward( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + benchmark_combined( + fn, + *inputs, + grad=grad, + repeats=repeats, + desc=desc, + verbose=verbose, + amp=amp, + amp_dtype=amp_dtype, + **kwinputs, + ), + ) + + +def pytorch_profiler( + fn, + *inputs, + trace_filename=None, + backward=False, + amp=False, + amp_dtype=torch.float16, + cpu=False, + verbose=True, + **kwinputs, +): + """Wrap benchmark functions in Pytorch profiler to see CUDA information.""" + if backward: + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + g = torch.randn_like(out) + for _ in range(30): # Warm up + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + # Backward should be done outside autocast + if backward: + out.backward(g, retain_graph=True) + activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [ + torch.profiler.ProfilerActivity.CUDA + ] + with torch.profiler.profile( + activities=activities, + record_shapes=True, + # profile_memory=True, + with_stack=True, + ) as prof: + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + out = fn(*inputs, **kwinputs) + if type(out) is tuple: + out = out[0] + if backward: + out.backward(g, retain_graph=True) + if verbose: + # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) + print(prof.key_averages().table(row_limit=50)) + if trace_filename is not None: + prof.export_chrome_trace(trace_filename) + + +def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + fn(*inputs, **kwinputs) + torch.cuda.synchronize() + mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000) + if verbose: + print(f"{desc} max memory: {mem}GB") + torch.cuda.empty_cache() + return mem diff --git a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py new file mode 100644 index 00000000000..e2ff2ccc9ae --- /dev/null +++ b/flash_attn/cute/blackwell_helpers.py @@ -0,0 +1,753 @@ +# Copyright (c) 2025, Tri Dao. 
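
The benchmark helpers above wrap `torch.utils.benchmark.Timer` and the PyTorch profiler; each `benchmark_*` call returns the `Timer` together with its `Measurement`. A small usage sketch against a toy CUDA workload (the matmul is only a stand-in; any callable works, and a CUDA device is required):

```python
import torch
from flash_attn.cute.benchmark import benchmark_forward, benchmark_memory

x = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

def matmul(t):
    return t @ t

timer, measurement = benchmark_forward(matmul, x, repeats=30, desc="matmul", verbose=False)
print(f"mean forward time: {measurement.mean * 1e3:.3f} ms")  # Measurement.mean is in seconds

peak = benchmark_memory(matmul, x, desc="matmul", verbose=False)
print(f"peak allocated memory: {peak:.3f}")  # units as computed in the helper (~GB)
```
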
+from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Boolean, const_expr +from cutlass.cute.nvgpu import tcgen05 +from cutlass._mlir.dialects import llvm + +import flash_attn.cute.mma_sm100_desc as sm100_desc +from flash_attn.cute.utils import parse_swizzle_from_pointer + + +@cute.jit +def gemm_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + swap_AB: bool = False, +) -> None: + if const_expr(swap_AB): + return gemm_w_idx( + tiled_mma, acc, tCrB, tCrA, B_idx, A_idx, zero_init=zero_init, swap_AB=False + ) + else: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + mma_atom = cute.make_mma_atom(tiled_mma.op) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc) + + +@cute.jit +def gemm_ptx_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + zero_init: bool | Boolean = False, + **kwargs, +) -> None: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + sA_cur = None + if const_expr(sA is not None): + sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx] + sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx] + mma_atom = cute.make_mma_atom(tiled_mma.op) + acc_tmem_addr = acc.iterator.toint() + gemm_ptx_partial( + mma_atom.op, acc_tmem_addr, rA, rB, sA_cur, sB_cur, zero_init=zero_init, **kwargs + ) + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> cute.TiledMma: + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + tiled_mma.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + return tiled_mma + + +def i64_to_i32x2(i: int) -> Tuple[int, int]: + """Convert a 64-bit integer to a tuple of two 32-bit integers.""" + return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF + + +@cute.jit +def gemm_ptx( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else None + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = parse_swizzle_from_pointer(sA.iterator) + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + 
smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = parse_swizzle_from_pointer(sB.iterator) + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr( + sA[None, None, 0].iterator + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr( + sB[None, None, 0].iterator + ) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if const_expr(not is_ts): + smem_desc_a_lo = smem_desc_start_a_lo + ( + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + ) + smem_desc_b_lo = smem_desc_start_b_lo + ( + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + ) + # with cute.arch.elect_one(): + # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo) + # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct) + with cute.arch.elect_one(): + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + smem_desc_a_lo.ir_value(), + smem_desc_b_lo.ir_value(), + Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + ".reg .b32 idesc;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + tCrA[None, None, k].iterator.toint().ir_value(), + smem_desc_b_lo.ir_value(), + Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_b;\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_loop( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = parse_swizzle_from_pointer(sA.iterator) + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, 
sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = parse_swizzle_from_pointer(sB.iterator) + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + if const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [ + offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [ + offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2])) + ] + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 smem_desc_a_lo, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, 
cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 tmem_a, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: Int32, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[Int32] = None, + zero_init: bool | Boolean = False, + # sA_offset: Int32 = 0, + # acc_offset: Int32 = 0, + tA_addr: Optional[Int32] = None, +) -> None: + # acc_tmem_addr += acc_offset + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + sA_swizzle = parse_swizzle_from_pointer(sA.iterator) + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + sB_swizzle = parse_swizzle_from_pointer(sB.iterator) + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + + tCrA_layout = ( + tCrA.layout + if const_expr(not is_ts) + else 
cute.recast_layout(32, tCrA.element_type.width, tCrA.layout) + ) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if const_expr(not is_ts): + smem_desc_start_a_lo = Int32( + smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + ) + # ) + sA_offset + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = Int32( + smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + ) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + "mov.b32 smem_desc_a_lo_start, $0;\n\t" + "mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + # "r,r,r", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to + # explicitly pass in the tA_addr for correctness. 
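# When mbar_ptr is provided, the inline asm below issues roughly the first
# three quarters of the k-blocks, waits on the mbarrier, and only then issues
# the remaining k-blocks.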
+ tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr + input_args = [ + # Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(), + Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(), + Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(not zero_init).ir_value(), + Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(), + ] + if const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + # Int32(smem_desc_start_b_lo).ir_value(), + # Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo_start;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_acc, $3;\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo_start, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range( + 1, + cute.size(tCrA.shape[2]) + if const_expr(mbar_ptr is None) + else cute.size(tCrA.shape[2]) // 4 * 3, + ) + ) + + mbar_wait_str + + ( + "".join( + ( + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(cute.size(tCrA.shape[2]) // 4 * 3, cute.size(tCrA.shape[2])) + ) + if const_expr(mbar_ptr is not None) + else "" + ) + + "}\n", + "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial1( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: cutlass.Constexpr[int], + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA_base_addr_for_desc: Int32, + sA_addr_offset_for_desc: cutlass.Constexpr[int], + sA_stage: Int32, + sB_base_addr_for_desc: Int32, + sB_addr_offset_for_desc: cutlass.Constexpr[int], + sB_stage: Int32, + 
sA_layout: Optional[cute.Layout], + sB_layout: Optional[cute.Layout], + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if const_expr(not is_ts): + assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + if const_expr(not is_ts): + smem_desc_base_a: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K + if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = const_expr( + sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K + if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) + else sm100_desc.Major.MN, + ) + ) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = const_expr(smem_desc_b_hi) + mask = [Int32(0)] * 4 + + if const_expr(not is_ts): + offset_a = [ + (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 + for k in range(cute.size(tCrA.shape[2])) + ] + else: + offset_a = [ + cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in range(cute.size(tCrA.shape[2])) + ] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [ + (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4 + for k in range(cute.size(tCrB.shape[2])) + ] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if const_expr(not is_ts): + # smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo) + else: + smem_desc_start_a_lo = None + # smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo) + pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1" + if const_expr(not is_ts): + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + Int32(sA_base_addr_for_desc).ir_value(), + Int32(sA_stage).ir_value(), + # Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + Int32(sB_base_addr_for_desc).ir_value(), + Int32(sB_stage).ir_value(), + Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, 
{hex(acc_tmem_addr)};\n\t" + # "mov.b32 smem_desc_a_lo, $0;\n\t" + # f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t" + f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t" + # "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $4, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + Int32(smem_desc_start_b_lo).ir_value(), + Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_a, $1;\n\t" + f"mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py new file mode 100644 index 00000000000..be13e70f892 --- /dev/null +++ b/flash_attn/cute/block_info.py @@ -0,0 +1,108 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
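As a plain-Python sanity check of the descriptor packing used throughout blackwell_helpers.py above, where the 64-bit SMEM descriptor is split into a per-k-block `lo` half and a constant `hi` half (the descriptor value below is made up):

def i64_to_i32x2(i):
    # Same split as the helper above: (low 32 bits, high 32 bits).
    return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF

desc = 0x4000_0000_0123_4560  # made-up 64-bit SMEM descriptor value
lo, hi = i64_to_i32x2(desc)
assert (hi << 32) | lo == desc
# Only `lo` changes per k-block (start address plus offset); `hi` is emitted as a PTX constant.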
+from typing import Tuple, Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + +from flash_attn.cute.seqlen_info import SeqlenInfoQK + + +@dataclass(frozen=True) +class BlockInfo: + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] + is_causal: cutlass.Constexpr[bool] + is_local: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + window_size_left: Optional[Int32] = None + window_size_right: Optional[Int32] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + + @cute.jit + def get_n_block_min_max( + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + split_idx: cutlass.Int32 = 0, + num_splits: cutlass.Int32 = 1, + ) -> Tuple[Int32, Int32]: + n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) + if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)): + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_right = n_idx if const_expr(self.is_causal) else n_idx + self.window_size_right + n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.tile_n)) + n_block_min = 0 + if const_expr(self.is_local and self.window_size_left is not None): + m_idx_min = m_block * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa + n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_left = n_idx - self.window_size_left + n_block_min = cutlass.max(n_idx_left // self.tile_n, 0) + if cutlass.const_expr(self.is_split_kv): + num_n_blocks_per_split = ( + cutlass.Int32(0) + if n_block_max <= n_block_min + else (n_block_max - n_block_min + num_splits - 1) // num_splits + ) + n_block_min = n_block_min + split_idx * num_n_blocks_per_split + n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max) + return n_block_min, n_block_max + + @cute.jit + def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]: + m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m) + m_block_min = 0 + if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)): + n_idx_min = n_block * self.tile_n + m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k + m_idx_right = m_idx if const_expr(self.is_causal) else m_idx - self.window_size_right + m_block_min = max(m_block_min, m_idx_right // self.tile_m) + if const_expr(self.is_local and self.window_size_left is not None): + n_idx_max = (n_block + 1) * self.tile_n + m_idx = n_idx_max + seqlen_info.seqlen_q - seqlen_info.seqlen_k + m_idx_left = m_idx + self.window_size_left + m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m)) + return m_block_min, m_block_max + + @cute.jit + def get_n_block_min_causal_local_mask( + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + n_block_min: Int32, + ) -> Int32: + """If we have separate iterations with causal or local masking at the start, where do we stop""" + m_idx_min = m_block * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa + n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_right = ( + n_idx + if const_expr(not self.is_local or self.window_size_right is None) + else n_idx + 
self.window_size_right + ) + return cutlass.max(n_block_min, n_idx_right // self.tile_n) + + @cute.jit + def get_n_block_min_before_local_mask( + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + n_block_min: Int32, + ) -> Int32: + """If we have separate iterations with local masking at the end, where do we stop the non-masked iterations""" + if const_expr(not self.is_local or self.window_size_left is None): + return n_block_min + else: + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_left = n_idx - self.window_size_left + return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.tile_n)) diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py new file mode 100644 index 00000000000..fe1c4cea812 --- /dev/null +++ b/flash_attn/cute/block_sparse_utils.py @@ -0,0 +1,1451 @@ +""" +Block-sparse runtime utilities for CUTE DSL kernels. + +This module contains runtime execution functions for block-sparse attention kernels. +These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads. +""" + +from typing import Callable, Optional +from functools import partial +import math +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr + +# Import data structures from block_sparsity +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute import utils +from flash_attn.cute import copy_utils +from flash_attn.cute.named_barrier import NamedBarrierBwd + + +@cute.jit +def load_block_list( + block_indices: cute.Tensor, + block_count, + load_q_with_first: cutlass.Constexpr, + first_block_preloaded: cutlass.Constexpr, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_k, + pipeline_v, + use_tma_q: cutlass.Constexpr, + tma_q_bytes: cutlass.Constexpr, + intra_wg_overlap: cutlass.Constexpr, +): + """Iterate over the sparse blocks and load K, V (and Q) into the pipeline. + for the intra_wg_overlap case, we overlap the loads of K and V. And this + means we need to pipeline the last V load from the partial block case, + with the loads for the full blocks. Set first_block_preloaded when the + caller has already issued the first K load for the list. + + Note: + we iterate along the block_n indices in reverse. + + Returns: + Updated kv_producer_state after processing the block list. 
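    Example (illustrative, indices made up): with block_indices = [2, 5, 7],
    block_count = 3 and intra_wg_overlap=True, the issue order is K7 (together
    with Q when load_q_with_first is set), then K5 overlapped with V7, then K2
    overlapped with V5; the final V2 is left pending and must be drained by
    finish_overlap_v_load.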
+ + """ + if block_count > 0: + if const_expr(not intra_wg_overlap): + # Peel first iteration: the first block may need to load Q alongside K, + # Parameters are already Constexpr, so no need to wrap in const_expr() + n_block_first = block_indices[block_count - 1] + extra_tx = tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0 + pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx) + + if const_expr(load_q_with_first and use_tma_q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + + load_K(src_idx=n_block_first, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_first, producer_state=kv_producer_state) + kv_producer_state.advance() + + for offset in cutlass.range(1, block_count): + n_block = block_indices[block_count - 1 - offset] + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + else: + n_block_first = block_indices[block_count - 1] + if const_expr(not first_block_preloaded): + extra_tx = ( + tma_q_bytes if const_expr(load_q_with_first) and const_expr(use_tma_q) else 0 + ) + pipeline_k.producer_acquire(kv_producer_state, extra_tx_count=extra_tx) + + if const_expr(load_q_with_first and use_tma_q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + + load_K(src_idx=n_block_first, producer_state=kv_producer_state) + + for idx in cutlass.range(block_count - 1, unroll=1): + n_block_prev = block_indices[block_count - 1 - idx] + n_block = block_indices[block_count - 2 - idx] + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) + + return kv_producer_state + + +@cute.jit +def finish_overlap_v_load( + block_indices: cute.Tensor, + block_count, + load_V, + pipeline_v, + kv_producer_state, +): + """Load the final V block after overlapped K/V loads.""" + if block_count > 0: + n_block_last = block_indices[0] + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block_last, producer_state=kv_producer_state) + kv_producer_state.advance() + + return kv_producer_state + + +@cute.jit +def sparse_tensor_m_block( + m_block, + qhead_per_kvhead: cutlass.Constexpr[int], +): + """Map packed m_block indices to block-sparse tensor indices.""" + if const_expr(qhead_per_kvhead != 1): + return m_block // qhead_per_kvhead + return m_block + + +@cute.jit +def produce_block_sparse_loads( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_k, + pipeline_v, + use_tma_q: cutlass.Constexpr, + tma_q_bytes: cutlass.Constexpr, + intra_wg_overlap: cutlass.Constexpr, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, +): + """Iterate over the mask and full block lists for a single tile. + + The masked (partial) list may leave the last V load pending when intra-warp-group + overlap is enabled. The first full block must consume that pending V while + issuing its own K load on the next pipeline stage. 
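    Q is issued together with the first K of whichever list runs first: the
    masked list when it is non-empty, otherwise the full list.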
+ + In the intra-wg-overlap path, the last masked block leaves its V copy in flight + while we advance the producer state to start the next full K. Either the full list + overlaps that pending V load, or, if no full blocks exist, we explicitly drain it. + + Args: + qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and + must be converted to unpacked for sparse tensor indexing. + """ + + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead) + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] + + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + + mask_empty = curr_mask_block_cnt == 0 + full_empty = curr_full_block_cnt == 0 + + if mask_empty: + # No masked blocks: the full list owns the initial Q+K load. + kv_producer_state = load_block_list( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=True, + first_block_preloaded=False, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + use_tma_q=use_tma_q, + tma_q_bytes=tma_q_bytes, + intra_wg_overlap=intra_wg_overlap, + ) + + if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0: + kv_producer_state = finish_overlap_v_load( + curr_full_block_idx, + curr_full_block_cnt, + load_V, + pipeline_v, + kv_producer_state, + ) + else: + # Masked blocks present: load Q together with the first masked K so consumers can + # start immediately. When overlap is disabled this fully drains the list. + kv_producer_state = load_block_list( + curr_mask_block_idx, + curr_mask_block_cnt, + load_q_with_first=True, + first_block_preloaded=False, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + use_tma_q=use_tma_q, + tma_q_bytes=tma_q_bytes, + intra_wg_overlap=intra_wg_overlap, + ) + + if full_empty: + if const_expr(intra_wg_overlap): + kv_producer_state = finish_overlap_v_load( + curr_mask_block_idx, + curr_mask_block_cnt, + load_V, + pipeline_v, + kv_producer_state, + ) + else: + if const_expr(intra_wg_overlap): + # Bridge the masked list to the full list by overlapping the pending masked V + # with the first full K load. 
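# The first full K is issued on a freshly advanced pipeline stage while the
# pending V of the last masked block completes on the previous stage, so the
# K/V overlap pattern continues across the list boundary.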
+ n_block_mask_last = curr_mask_block_idx[0] + n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1] + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block_full_first, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev) + + kv_producer_state = load_block_list( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=False, + first_block_preloaded=True, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + use_tma_q=use_tma_q, + tma_q_bytes=tma_q_bytes, + intra_wg_overlap=intra_wg_overlap, + ) + + kv_producer_state = finish_overlap_v_load( + curr_full_block_idx, + curr_full_block_cnt, + load_V, + pipeline_v, + kv_producer_state, + ) + else: + # Non-overlap path with both lists: run the full list normally (skipping the Q + # reload because the masked list already issued it). + kv_producer_state = load_block_list( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=False, + first_block_preloaded=False, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + use_tma_q=use_tma_q, + tma_q_bytes=tma_q_bytes, + intra_wg_overlap=intra_wg_overlap, + ) + + return kv_producer_state + + +@cute.jit +def consume_block_sparse_loads( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + seqlen, + kv_consumer_state, + mma_pv_fn, + mma_one_n_block, + process_first_half_block, + process_last_half_block, + mask_fn, + score_mod_fn, + O_should_accumulate, + mask_mod, + fastdiv_mods, + intra_wg_overlap: cutlass.Constexpr, + warp_scheduler_barrier_sync: Callable, + warp_scheduler_barrier_arrive: Callable, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, +): + """Consume the mask and full block lists for a single tile on the consumer side. + + Mirrors `produce_block_sparse_loads` so that the consumer pipeline uses + the same sparse tensor indexing. + + Args: + qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and + must be converted to unpacked for sparse tensor indexing. 
+ """ + + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead) + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] + + processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0 + + if const_expr(not intra_wg_overlap): + if curr_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial( + mask_fn, + mask_mod=mask_mod, + mask_seqlen=True, + fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None, + ), + is_first_n_block=True, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False), + is_first_n_block=False, + ) + O_should_accumulate = True + if curr_full_block_cnt == 0: + warp_scheduler_barrier_arrive() + + if curr_full_block_cnt > 0: + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] + if curr_mask_block_cnt == 0: + warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=True), + is_first_n_block=True, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_seqlen=False), + is_first_n_block=False, + ) + O_should_accumulate = True + else: + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), + is_first_n_block=False, + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), + is_first_n_block=False, + ) + O_should_accumulate = True + warp_scheduler_barrier_arrive() + else: + if curr_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + kv_consumer_state = process_first_half_block( + n_block=mask_n_block, + seqlen=seqlen, + kv_consumer_state=kv_consumer_state, + mask_fn=partial( + mask_fn, + mask_mod=mask_mod, + mask_seqlen=True, + fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None, + ), + score_mod_fn=score_mod_fn, + is_first_block=True, + ) + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = 
curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=mask_n_block, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + + if curr_full_block_cnt > 0: + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] + if curr_mask_block_cnt == 0: + kv_consumer_state = process_first_half_block( + n_block=full_n_block, + seqlen=seqlen, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), + score_mod_fn=score_mod_fn, + is_first_block=True, + ) + else: + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True), + ) + O_should_accumulate = True + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=full_n_block, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False), + ) + O_should_accumulate = True + + if curr_mask_block_cnt + curr_full_block_cnt > 0: + kv_consumer_state = process_last_half_block( + kv_consumer_state=kv_consumer_state, + zero_init=not O_should_accumulate, + ) + O_should_accumulate = True + + return kv_consumer_state, O_should_accumulate, processed_any + + +@cute.jit +def load_block_list_sm100( + block_indices: cute.Tensor, + block_count, + load_q_with_first: cutlass.Constexpr, + m_block, + q_stage: cutlass.Constexpr, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_kv, +): + """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).""" + if block_count > 0: + # First iteration: load Q alongside K if requested + n_block_first = block_indices[block_count - 1] + + if const_expr(load_q_with_first): + # SM100 loads Q0 and optionally Q1 + load_Q(block=q_stage * m_block + 0, stage=0) + if const_expr(q_stage == 2): + load_Q(block=q_stage * m_block + 1, stage=1) + + # SM100 doesn't use producer_acquire for pipeline_kv in load path + # The pipeline barriers are handled inside load_KV + load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + + # Remaining blocks + for offset in cutlass.range(1, block_count): + n_block = block_indices[block_count - 1 - offset] + load_K(block=n_block, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + load_V(block=n_block, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + + return kv_producer_state + + +# SM100-specific tile processor using SM100 helpers +@cute.jit +def produce_block_sparse_loads_sm100( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_kv, + q_stage: cutlass.Constexpr, + q_producer_phase: Int32, + qhead_per_kvhead: cutlass.Constexpr, +): + """SM100 entry point for sparse block iteration. + + SM100 uses PipelineTmaUmma which doesn't support extra_tx_count, so we use + simplified block processing that just calls producer_acquire without extras. 
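    Returns the updated (kv_producer_state, q_producer_phase). The Q producer
    phase is flipped only when at least one block (masked or full) was loaded
    for this tile, i.e. only when Q was actually issued.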
+ + Args: + m_block: which tile of m we are processing + qhead_per_kvhead: Constexpr pack factor + """ + # NB: Compute unpacked index for sparse tensor access + if const_expr(qhead_per_kvhead != 1): + m_block_sparse = m_block // qhead_per_kvhead + else: + m_block_sparse = m_block + + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] + + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + + mask_empty = curr_mask_block_cnt == 0 + full_empty = curr_full_block_cnt == 0 + + q_phase_flipped = False + + if mask_empty: + # No masked blocks: process full list with Q loading + kv_producer_state = load_block_list_sm100( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=True, + m_block=m_block, + q_stage=q_stage, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_kv=pipeline_kv, + ) + q_phase_flipped = not full_empty + else: + # Process masked blocks with Q loading + kv_producer_state = load_block_list_sm100( + curr_mask_block_idx, + curr_mask_block_cnt, + load_q_with_first=True, + m_block=m_block, + q_stage=q_stage, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_kv=pipeline_kv, + ) + q_phase_flipped = True + + if not full_empty: + # Process full blocks without Q loading + kv_producer_state = load_block_list_sm100( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=False, + m_block=m_block, + q_stage=q_stage, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_kv=pipeline_kv, + ) + + if q_phase_flipped: + q_producer_phase ^= 1 + + return kv_producer_state, q_producer_phase + + +@cute.jit +def get_total_block_count( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + qhead_per_kvhead: cutlass.Constexpr, +): + # NB: Convert packed m_block to unpacked for sparse tensor indexing + if const_expr(qhead_per_kvhead != 1): + m_block_sparse = m_block // qhead_per_kvhead + else: + m_block_sparse = m_block + + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + if const_expr(full_block_cnt is not None): + return ( + mask_block_cnt[batch_idx, head_idx, m_block_sparse] + + full_block_cnt[batch_idx, head_idx, m_block_sparse] + ) + else: + return mask_block_cnt[batch_idx, head_idx, m_block_sparse] + + +@cute.jit +def handle_block_sparse_empty_tile_correction_sm100( + tidx: Int32, + q_stage: cutlass.Constexpr, + m_block_size: cutlass.Constexpr, + qhead_per_kvhead, + pack_gqa: cutlass.Constexpr, + is_split_kv: cutlass.Constexpr, + learnable_sink, + mLSE, + seqlen, + m_block: Int32, + head_idx: Int32, + batch_idx: Int32, + split_idx: Int32, + sScale: cute.Tensor, + stats: list, + correction_epilogue: Callable, + thr_mma_pv: cute.core.ThrMma, + tOtOs: tuple[cute.Tensor], + sO: cute.Tensor, + mbar_ptr, + mbar_softmax_corr_full_offset: Int32, + mbar_softmax_corr_empty_offset: Int32, + mbar_P_full_O_rescaled_offset: Int32, + mbar_P_full_2_offset: Int32, + mbar_corr_epi_full_offset: Int32, + mbar_corr_epi_empty_offset: Int32, + softmax_corr_consumer_phase: Int32, + 
o_corr_consumer_phase: Int32, + corr_epi_producer_phase: Int32, + softmax_scale_log2: Float32, + mO_cur: Optional[cute.Tensor] = None, + gO: Optional[cute.Tensor] = None, + gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, +): + """Handle the block-sparse case where a tile is fully masked: + * zero staged results + * seed stats + * satisfy the usual barrier protocol so downstream warps continue to make progress. + """ + LOG2_E = Float32(math.log2(math.e)) + + for stage in cutlass.range_constexpr(q_stage): + row_sum_value = Float32(1.0) + row_max_value = ( + -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None + ) + if const_expr(learnable_sink is not None): + sink_val = -Float32.inf + if const_expr(not pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + elif tidx < m_block_size: + q_head_idx = ( + (q_stage * m_block + stage) * m_block_size + tidx + ) % qhead_per_kvhead + head_idx * qhead_per_kvhead + sink_val = Float32(learnable_sink[q_head_idx]) + if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0): + if row_max_value == -Float32.inf: + row_max_value = sink_val * (LOG2_E / softmax_scale_log2) + row_sum_value = Float32(1.0) + else: + row_sum_value = row_sum_value + utils.exp2f( + sink_val * LOG2_E - row_max_value * softmax_scale_log2 + ) + if tidx < m_block_size: + scale_row_idx = tidx + stage * m_block_size + sScale[scale_row_idx] = row_sum_value + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[scale_row_idx + m_block_size * 2] = row_max_value + acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value + stats[stage] = (row_sum_value, row_max_value, acc_flag) + + cute.arch.mbarrier_wait( + mbar_ptr + mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage) + + if const_expr(gmem_tiled_copy_O is None): + cute.arch.mbarrier_wait( + mbar_ptr + mbar_corr_epi_empty_offset + stage, + corr_epi_producer_phase, + ) + correction_epilogue( + thr_mma_pv, + tOtOs[stage], + tidx, + stage, + m_block, + seqlen.seqlen_q, + Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs + sO[None, None, stage], + mO_cur, + gO, + gmem_tiled_copy_O, + ) + if const_expr(gmem_tiled_copy_O is None): + cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage) + + softmax_corr_consumer_phase ^= 1 + o_corr_consumer_phase ^= 1 + corr_epi_producer_phase ^= 1 + + return ( + softmax_corr_consumer_phase, + o_corr_consumer_phase, + corr_epi_producer_phase, + ) + + +@cute.jit +def softmax_block_sparse_sm100( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + softmax_step: Callable, + mask_fn: Callable, + mask_fn_none: Callable, + mma_si_consumer_phase: Int32, + si_corr_producer_phase: Int32, + s0_s1_sequence_phase: Int32, + mbar_ptr, + mbar_softmax_corr_full_offset: Int32, + mbar_softmax_corr_empty_offset: Int32, + mbar_P_full_O_rescaled_offset: Int32, + mbar_P_full_2_offset: Int32, + q_stage: cutlass.Constexpr, + stage_idx: Int32, + check_m_boundary: bool, + qhead_per_kvhead: cutlass.Constexpr, +): + # Convert packed m_block to unpacked for sparse tensor indexing + if const_expr(qhead_per_kvhead != 1): + m_block_sparse = m_block // qhead_per_kvhead + else: + m_block_sparse = m_block + + mask_block_cnt, 
mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] + + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + + total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt + + if total_block_cnt == 0: + cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_full_offset + stage_idx) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage_idx) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage_idx) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage_idx) + else: + if curr_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + mask_n_block, + is_first=True, + mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary), + ) + for i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + mask_n_block, + mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary), + ) + + if curr_full_block_cnt > 0: + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] + if curr_mask_block_cnt == 0: + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + full_n_block, + is_first=True, + mask_fn=partial( + mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary + ), + ) + else: + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + full_n_block, + is_first=False, + mask_fn=partial( + mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary + ), + ) + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + full_n_block, + mask_fn=partial( + mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary + ), + ) + + return ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + total_block_cnt == 0, + ) + + +# ============================================================================= +# Backward-specific block-sparse helpers (SM100) +# ============================================================================= +# +# In backward, iteration is transposed compared to forward: +# - Forward: outer loop over m_blocks (Q tiles), inner loop over n_blocks (KV tiles) +# - Backward: outer loop over n_blocks (KV tiles), inner loop over m_blocks (Q tiles) +# +# The backward block-sparse tensors use "Q direction" indexing: +# - q_block_cnt[batch, head, n_block] → count of m_blocks to process 
for this KV tile +# - q_block_idx[batch, head, n_block, :] → indices of m_blocks to process +# + + +@cute.jit +def get_total_q_block_count_bwd( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + n_block, + subtile_factor: cutlass.Constexpr = 1, + m_block_max: int = 0, +): + """Count total tile iterations for given n_block (KV tile) in backward.""" + q_block_cnt, _, full_block_cnt, _ = blocksparse_tensors + total = q_block_cnt[batch_idx, head_idx, n_block] + if const_expr(full_block_cnt is not None): + total = total + full_block_cnt[batch_idx, head_idx, n_block] + return total * subtile_factor + + +@cute.jit +def produce_block_sparse_q_loads_bwd_sm100( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + n_block, + # Pipeline states (will be returned after advancing) + producer_state_Q_LSE, + producer_state_dO_dPsum, + # Pipelines + pipeline_Q, + pipeline_LSE, + pipeline_dO, + pipeline_dPsum, + # Load functions + load_K, + load_V, + load_Q, + load_dO, + copy_stats, + # Global tensors for LSE/dPsum + gLSE, + sLSE, + gdPsum, + sdPsum, + # TMA copy bytes for extra_tx_count + tma_copy_bytes_K, + tma_copy_bytes_V, + # Flags for which loads to perform + should_load_Q: cutlass.Constexpr, + should_load_dO: cutlass.Constexpr, + # Subtiling factor and bounds + subtile_factor: cutlass.Constexpr = 1, + m_block_max: int = 0, +): + """SM100 backward block sparse loading with subtiling. + + Returns updated (producer_state_Q_LSE, producer_state_dO_dPsum). + First iteration loads K/V alongside Q/dO; subsequent iterations load only Q/dO. + """ + ( + curr_q_cnt, + curr_q_idx, + curr_full_cnt, + curr_full_idx, + loop_count, + ) = get_block_sparse_iteration_info_bwd( + blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max + ) + + for iter_idx in cutlass.range(loop_count, unroll=1): + m_block, _ = get_m_block_from_iter_bwd( + iter_idx, + curr_q_cnt, + curr_q_idx, + curr_full_cnt, + curr_full_idx, + subtile_factor, + m_block_max, + ) + m_block_safe = m_block + if m_block_max > 0: + m_block_safe = cutlass.min(m_block, m_block_max - 1) + + if iter_idx == 0: + # First block: load K/V alongside Q/dO + if const_expr(should_load_Q): + pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) + load_Q(m_block_safe, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) + pipeline_LSE.producer_acquire(producer_state_Q_LSE) + with cute.arch.elect_one(): + copy_stats( + gLSE[None, m_block_safe], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), + ) + producer_state_Q_LSE.advance() + if const_expr(should_load_dO): + pipeline_dO.producer_acquire( + producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V + ) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum)) + load_dO(m_block_safe, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) + with cute.arch.elect_one(): + copy_stats( + gdPsum[None, m_block_safe], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), + ) + producer_state_dO_dPsum.advance() + else: + # Subsequent blocks: just load Q/dO (K/V already loaded) + if const_expr(should_load_Q): + pipeline_Q.producer_acquire(producer_state_Q_LSE) + load_Q(m_block_safe, 
producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) + pipeline_LSE.producer_acquire(producer_state_Q_LSE) + with cute.arch.elect_one(): + copy_stats( + gLSE[None, m_block_safe], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), + ) + producer_state_Q_LSE.advance() + if const_expr(should_load_dO): + pipeline_dO.producer_acquire(producer_state_dO_dPsum) + load_dO(m_block_safe, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) + with cute.arch.elect_one(): + copy_stats( + gdPsum[None, m_block_safe], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), + ) + producer_state_dO_dPsum.advance() + + return producer_state_Q_LSE, producer_state_dO_dPsum + + +@cute.jit +def get_block_sparse_iteration_info_bwd( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + n_block, + subtile_factor: cutlass.Constexpr = 1, + m_block_max: int = 0, +): + """Extract block-sparse iteration info for backward pass. + + Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count). + """ + q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] + curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] + + if const_expr(full_cnt is not None): + curr_full_cnt = full_cnt[batch_idx, head_idx, n_block] + curr_full_idx = full_idx[batch_idx, head_idx, n_block, None] + else: + curr_full_cnt = Int32(0) + curr_full_idx = None + + sparse_block_count = curr_q_cnt + if const_expr(full_cnt is not None): + sparse_block_count = sparse_block_count + curr_full_cnt + total_count = sparse_block_count * subtile_factor + + return curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count + + +@cute.jit +def get_m_block_from_iter_bwd( + iter_idx, + curr_q_cnt, + curr_q_idx: cute.Tensor, + curr_full_cnt, + curr_full_idx: Optional[cute.Tensor], + subtile_factor: cutlass.Constexpr = 1, + m_block_max: int = 0, +): + """Derive m_block index and is_full_block flag from iteration index. 
+ + Returns (m_block, is_full_block): + - m_block: The actual Q-tile block index + - is_full_block: True if this is a full block (no mask_mod needed) + """ + sparse_iter_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + + sparse_m_block = Int32(0) + is_full_block = False + if const_expr(curr_full_idx is not None): + if sparse_iter_idx < curr_q_cnt: + sparse_m_block = curr_q_idx[sparse_iter_idx] + else: + sparse_m_block = curr_full_idx[sparse_iter_idx - curr_q_cnt] + is_full_block = True + else: + sparse_m_block = curr_q_idx[sparse_iter_idx] + + return sparse_m_block * subtile_factor + subtile_offset, is_full_block + + +@cute.jit +def _load_q_do_block_sm90( + m_block, + producer_state_Q, + producer_state_dO, + pipeline_Q, + pipeline_dO, + load_K, + load_V, + load_Q, + load_dO, + load_LSE, + load_dPsum, + tma_copy_bytes_K, + tma_copy_bytes_V, + Q_stage_eq_dO_stage: cutlass.Constexpr, + load_kv: bool, +): + """Load one Q/dO block, optionally loading K/V on first iteration.""" + if load_kv: + pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=tma_copy_bytes_K) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) + else: + pipeline_Q.producer_acquire(producer_state_Q) + load_Q(m_block, producer_state=producer_state_Q) + with cute.arch.elect_one(): + load_LSE(m_block, producer_state=producer_state_Q) + + producer_state_dO_cur = ( + producer_state_dO if const_expr(not Q_stage_eq_dO_stage) else producer_state_Q + ) + if load_kv: + pipeline_dO.producer_acquire(producer_state_dO_cur, extra_tx_count=tma_copy_bytes_V) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur)) + else: + pipeline_dO.producer_acquire(producer_state_dO_cur) + load_dO(m_block, producer_state=producer_state_dO_cur) + with cute.arch.elect_one(): + load_dPsum(m_block, producer_state=producer_state_dO_cur) + + producer_state_Q.advance() + producer_state_dO.advance() + return producer_state_Q, producer_state_dO + + +@cute.jit +def produce_block_sparse_q_loads_bwd_sm90( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + n_block, + producer_state_Q, + producer_state_dO, + pipeline_Q, + pipeline_dO, + load_K, + load_V, + load_Q, + load_dO, + load_LSE, + load_dPsum, + tma_copy_bytes_K, + tma_copy_bytes_V, + Q_stage_eq_dO_stage: cutlass.Constexpr, + subtile_factor: cutlass.Constexpr, + m_block_max: int, +): + """SM90 backward block sparse loading with separate partial/full loops. + + K/V are loaded with the first valid block. Iterates partial blocks first, + then full blocks, matching consumer order. + + Returns updated (producer_state_Q, producer_state_dO). 
+ """ + q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] + curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] + + if const_expr(full_cnt is not None): + curr_full_cnt = full_cnt[batch_idx, head_idx, n_block] + curr_full_idx = full_idx[batch_idx, head_idx, n_block, None] + else: + curr_full_cnt = Int32(0) + curr_full_idx = None + + kv_loaded = False + + for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1): + sparse_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset + + if m_block < m_block_max: + producer_state_Q, producer_state_dO = _load_q_do_block_sm90( + m_block, + producer_state_Q, + producer_state_dO, + pipeline_Q, + pipeline_dO, + load_K, + load_V, + load_Q, + load_dO, + load_LSE, + load_dPsum, + tma_copy_bytes_K, + tma_copy_bytes_V, + Q_stage_eq_dO_stage, + load_kv=not kv_loaded, + ) + kv_loaded = True + + if const_expr(full_cnt is not None): + for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1): + sparse_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset + + if m_block < m_block_max: + producer_state_Q, producer_state_dO = _load_q_do_block_sm90( + m_block, + producer_state_Q, + producer_state_dO, + pipeline_Q, + pipeline_dO, + load_K, + load_V, + load_Q, + load_dO, + load_LSE, + load_dPsum, + tma_copy_bytes_K, + tma_copy_bytes_V, + Q_stage_eq_dO_stage, + load_kv=not kv_loaded, + ) + kv_loaded = True + + return producer_state_Q, producer_state_dO + + +@cute.jit +def consume_block_sparse_mma_bwd_sm90( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + n_block, + consumer_state_Q, + consumer_state_dO, + mma_one_m_block_fn, + mask, + mask_mod, + is_causal: cutlass.Constexpr, + is_local: cutlass.Constexpr, + thr_mma_SdP, + softmax_scale, + seqlen, + subtile_factor: cutlass.Constexpr, + m_block_max: int, + aux_tensors=None, + fastdiv_mods=(None, None), +): + """SM90 backward block sparse MMA consumption with separate partial/full loops. + + Partial blocks are processed first (with mask_mod applied), then full blocks + (without mask_mod). This ensures mask_mod is only applied where needed. + + Returns updated (consumer_state_Q, consumer_state_dO). 
+ """ + q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] + curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] + + if const_expr(full_cnt is not None): + curr_full_cnt = full_cnt[batch_idx, head_idx, n_block] + curr_full_idx = full_idx[batch_idx, head_idx, n_block, None] + else: + curr_full_cnt = Int32(0) + curr_full_idx = None + + dKV_accumulate = False + + mask_fn_partial = partial( + mask.apply_mask, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + thr_mma=thr_mma_SdP, + mask_seqlen=True, + mask_causal=is_causal, + mask_local=is_local, + mask_mod=mask_mod, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + + mask_fn_full = partial( + mask.apply_mask, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + thr_mma=thr_mma_SdP, + mask_seqlen=True, + mask_causal=is_causal, + mask_local=is_local, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + + for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1): + sparse_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset + + if m_block < m_block_max: + consumer_state_Q, consumer_state_dO = mma_one_m_block_fn( + m_block, + consumer_state_Q, + consumer_state_dO, + mask_fn=mask_fn_partial, + dKV_accumulate=dKV_accumulate, + thr_mma_SdP=thr_mma_SdP, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + softmax_scale=softmax_scale, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + dKV_accumulate = True + + if const_expr(full_cnt is not None): + for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1): + sparse_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset + + if m_block < m_block_max: + consumer_state_Q, consumer_state_dO = mma_one_m_block_fn( + m_block, + consumer_state_Q, + consumer_state_dO, + mask_fn=mask_fn_full, + dKV_accumulate=dKV_accumulate, + thr_mma_SdP=thr_mma_SdP, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + softmax_scale=softmax_scale, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + dKV_accumulate = True + + return consumer_state_Q, consumer_state_dO + + +@cute.jit +def _store_one_dQaccum_sm90( + m_block, + sdQaccum: cute.Tensor, + gdQaccum: cute.Tensor, + num_mma_warp_groups: cutlass.Constexpr, + num_threads_per_warp_group: cutlass.Constexpr, + tma_copy_bytes_dQ, +): + """Store dQaccum for a single m_block.""" + for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups): + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, + number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + with cute.arch.elect_one(): + copy_utils.cpasync_reduce_bulk_add_f32( + sdQaccum[None, warp_group_idx].iterator, + gdQaccum[None, warp_group_idx, m_block].iterator, + tma_copy_bytes_dQ, + ) + cute.arch.cp_async_bulk_commit_group() + for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups): + cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + + +@cute.jit +def dQaccum_store_block_sparse_bwd_sm90( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + n_block, + 
sdQaccum: cute.Tensor, + gdQaccum: cute.Tensor, + subtile_factor: cutlass.Constexpr, + m_block_max: int, + num_mma_warp_groups: cutlass.Constexpr, + num_threads_per_warp_group: cutlass.Constexpr, + tma_copy_bytes_dQ, +): + """SM90 backward block sparse dQaccum store with separate partial/full loops. + + Iterates partial blocks first, then full blocks, matching producer/consumer order. + """ + q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] + curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] + + if const_expr(full_cnt is not None): + curr_full_cnt = full_cnt[batch_idx, head_idx, n_block] + curr_full_idx = full_idx[batch_idx, head_idx, n_block, None] + else: + curr_full_cnt = Int32(0) + curr_full_idx = None + + for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1): + sparse_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset + + if m_block < m_block_max: + _store_one_dQaccum_sm90( + m_block, + sdQaccum, + gdQaccum, + num_mma_warp_groups, + num_threads_per_warp_group, + tma_copy_bytes_dQ, + ) + + if const_expr(full_cnt is not None): + for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1): + sparse_idx = iter_idx // subtile_factor + subtile_offset = iter_idx % subtile_factor + m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset + + if m_block < m_block_max: + _store_one_dQaccum_sm90( + m_block, + sdQaccum, + gdQaccum, + num_mma_warp_groups, + num_threads_per_warp_group, + tma_copy_bytes_dQ, + ) diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py new file mode 100644 index 00000000000..59b0c017f3a --- /dev/null +++ b/flash_attn/cute/block_sparsity.py @@ -0,0 +1,250 @@ +""" +Block-sparsity utilities for FlexAttention +""" + +from typing import Callable, NamedTuple, Tuple + +import cutlass.cute as cute +import torch + +from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor + + +def ceildiv(a: int, b: int) -> int: + return (a + b - 1) // b + + +class BlockSparseTensors(NamedTuple): + mask_block_cnt: cute.Tensor + mask_block_idx: cute.Tensor + full_block_cnt: cute.Tensor | None + full_block_idx: cute.Tensor | None + + def __new_from_mlir_values__(self, values): + if len(values) == 2: + values = (*values, None, None) + return BlockSparseTensors(*values) + + +class BlockSparseTensorsTorch(NamedTuple): + mask_block_cnt: torch.Tensor + mask_block_idx: torch.Tensor + full_block_cnt: torch.Tensor | None = None + full_block_idx: torch.Tensor | None = None + + +def _expand_sparsity_tensor( + tensor: torch.Tensor, + expected_shape: Tuple[int, ...], + tensor_name: str, + context: str | None, + hint: str | Callable[[], str] | None, +) -> torch.Tensor: + """Check if we need to expand the tensor to expected shape, and do so if possible.""" + needs_expand = tensor.shape != expected_shape + if not needs_expand: + return tensor + can_expand = all(map(lambda cur, tgt: cur == tgt or cur == 1, tensor.shape, expected_shape)) + if not can_expand: + context_clause = f" ({context})" if context else "" + resolved_hint = hint() if callable(hint) else hint + hint_clause = f" Hint: {resolved_hint}" if resolved_hint else "" + raise ValueError( + f"{tensor_name}{context_clause} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}." 
+ f"{hint_clause}" + ) + return tensor.expand(*expected_shape) + + +def _check_and_expand_block( + name: str, + cnt: torch.Tensor | None, + idx: torch.Tensor | None, + expected_count_shape: Tuple[int, int, int], + expected_index_shape: Tuple[int, int, int, int], + context: str | None, + hint: str | Callable[[], str] | None, +) -> Tuple[torch.Tensor | None, torch.Tensor | None]: + if (cnt is None) != (idx is None): + raise ValueError( + f"{name}_block_cnt and {name}_block_idx must both be provided or both be None" + ) + if cnt is None or idx is None: + return None, None + if cnt.dtype != torch.int32 or idx.dtype != torch.int32: + raise ValueError(f"{name}_block tensors must have dtype torch.int32") + if cnt.device != idx.device: + raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device") + if not cnt.is_cuda or not idx.is_cuda: + raise ValueError(f"{name}_block tensors must live on CUDA") + expanded_cnt = _expand_sparsity_tensor( + cnt, expected_count_shape, f"{name}_block_cnt", context, hint + ) + expanded_idx = _expand_sparsity_tensor( + idx, expected_index_shape, f"{name}_block_idx", context, hint + ) + return expanded_cnt, expanded_idx + + +def get_block_sparse_expected_shapes( + batch_size: int, + num_head: int, + seqlen_q: int, + seqlen_k: int, + m_block_size: int, + n_block_size: int, + q_stage: int, +) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]: + """Return (expected_count_shape, expected_index_shape) for block sparse normalization.""" + m_block_size_effective = q_stage * m_block_size + expected_m_blocks = ceildiv(seqlen_q, m_block_size_effective) + expected_n_blocks = ceildiv(seqlen_k, n_block_size) + expected_count_shape = (batch_size, num_head, expected_m_blocks) + expected_index_shape = (batch_size, num_head, expected_m_blocks, expected_n_blocks) + return expected_count_shape, expected_index_shape + + +def get_block_sparse_expected_shapes_bwd( + batch_size: int, + num_head: int, + seqlen_q: int, + seqlen_k: int, + m_block_size: int, + n_block_size: int, + subtile_factor: int, +) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]: + """Return (expected_count_shape, expected_index_shape) for backward block sparse normalization. + + Backward uses Q-direction indexing (transposed from forward), where shapes are + indexed by N-blocks first, then M-blocks. The sparse_block_size_q is determined + by subtile_factor * m_block_size. 
+ """ + sparse_block_size_q = subtile_factor * m_block_size + expected_m_blocks = ceildiv(seqlen_q, sparse_block_size_q) + expected_n_blocks = ceildiv(seqlen_k, n_block_size) + expected_count_shape = (batch_size, num_head, expected_n_blocks) + expected_index_shape = (batch_size, num_head, expected_n_blocks, expected_m_blocks) + return expected_count_shape, expected_index_shape + + +def normalize_block_sparse_tensors( + tensors: BlockSparseTensorsTorch, + *, + expected_count_shape: Tuple[int, int, int], + expected_index_shape: Tuple[int, int, int, int], + context: str | None = None, + hint: str | Callable[[], str] | None = None, +) -> BlockSparseTensorsTorch: + if tensors.mask_block_cnt is None or tensors.mask_block_idx is None: + raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.") + + mask_cnt, mask_idx = _check_and_expand_block( + "mask", + tensors.mask_block_cnt, + tensors.mask_block_idx, + expected_count_shape, + expected_index_shape, + context, + hint, + ) + if mask_cnt is None or mask_idx is None: + raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.") + + full_cnt, full_idx = _check_and_expand_block( + "full", + tensors.full_block_cnt, + tensors.full_block_idx, + expected_count_shape, + expected_index_shape, + context, + hint, + ) + if full_cnt is not None and mask_cnt.device != full_cnt.device: + raise ValueError("All block sparse tensors must be on the same device") + + return BlockSparseTensorsTorch( + mask_block_cnt=mask_cnt, + mask_block_idx=mask_idx, + full_block_cnt=full_cnt, + full_block_idx=full_idx, + ) + + +def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool: + return any(t is not None for t in (tensors.full_block_cnt, tensors.mask_block_cnt)) + + +def get_block_sparse_broadcast_pattern( + tensors: BlockSparseTensorsTorch, +) -> Tuple[Tuple[bool, ...], ...] | None: + """Return broadcast pattern for block sparse tensors by checking actual strides. + + Returns a tuple of broadcast patterns (one per tensor) where each pattern + is a tuple of bools indicating which dims have stride=0. + This is used in compile keys to ensure kernels are recompiled when + broadcast patterns change, since CuTe's mark_layout_dynamic() keeps + stride=0 as static. + + The tensors should already be expanded/normalized before calling this function. + + Returns None if block sparsity is not enabled. 
+ """ + if not is_block_sparsity_enabled(tensors): + return None + + patterns = [] + for tensor in ( + tensors.mask_block_cnt, + tensors.mask_block_idx, + tensors.full_block_cnt, + tensors.full_block_idx, + ): + if tensor is not None: + patterns.append(get_broadcast_dims(tensor)) + else: + patterns.append(None) + return tuple(patterns) + + +def to_cute_block_sparse_tensors( + tensors: BlockSparseTensorsTorch, enable_tvm_ffi: bool = True +) -> BlockSparseTensors | None: + """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi""" + if not is_block_sparsity_enabled(tensors): + return None + ( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + ) = tensors + + ( + mask_block_cnt_tensor, + mask_block_idx_tensor, + ) = [ + to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) + for t in (mask_block_cnt, mask_block_idx) + ] + ( + full_block_cnt_tensor, + full_block_idx_tensor, + ) = [ + to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) + if t is not None + else None + for t in (full_block_cnt, full_block_idx) + ] + + return BlockSparseTensors( + mask_block_cnt_tensor, + mask_block_idx_tensor, + full_block_cnt_tensor, + full_block_idx_tensor, + ) + + +def fast_sampling(mask_mod): + """Convenience decorator to mark mask_mod as safe for 5-point fast sampling""" + mask_mod.use_fast_sampling = True + return mask_mod diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py new file mode 100644 index 00000000000..07499422d72 --- /dev/null +++ b/flash_attn/cute/compute_block_sparsity.py @@ -0,0 +1,377 @@ +from functools import partial +from typing import Callable, Optional, Tuple + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Boolean, Int8, Int32, const_expr + +from flash_attn.cute.block_sparsity import ( + BlockSparseTensors, + BlockSparseTensorsTorch, + to_cute_block_sparse_tensors, +) +from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_attn.cute.seqlen_info import SeqlenInfoQK + + +class BlockSparsityKernel: + """Block sparsity kernel for FlexAttention. + + This kernel computes `mask_mod` for every token of each block + to determine if an n block is full, masked, or neither. + + Writes block counts and indices to a BlockSparseTensors object. + + When use_fast_sampling=True, uses 5-point sampling (4 corners + center) + which is much faster but only suitable for masks where this is sufficient. 
+ + TODO: + - optimize mask_mod evaluation + - varlen support + - transposed tensors for bwd pass + """ + + def __init__( + self, + mask_mod: Callable, + tile_mn: Tuple[int, int], + compute_full_blocks: bool = True, + use_aux_tensors: bool = False, + use_fast_sampling: bool = False, + ): + self.mask_mod = mask_mod + self.tile_mn = tile_mn + self.compute_full_blocks = compute_full_blocks + self.use_aux_tensors = use_aux_tensors + self.use_fast_sampling = use_fast_sampling + + @cute.jit + def __call__( + self, + blocksparse_tensors: BlockSparseTensors, + seqlen_q: Int32, + seqlen_k: Int32, + aux_tensors: Optional[list] = None, + ): + self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors + + if const_expr(self.compute_full_blocks): + assert self.full_cnt is not None and self.full_idx is not None, ( + "full block tensors must be provided when computing full blocks" + ) + + batch_size, num_heads, num_m_blocks, num_n_blocks = self.mask_idx.shape + # launch 1 CTA per m block + grid = [num_m_blocks, num_heads, batch_size] + + if const_expr(self.use_fast_sampling): + num_threads = 5 + self.num_warps = 1 + else: + num_threads = self.tile_mn[0] + self.num_warps = (num_threads + 32 - 1) // 32 + + self.kernel( + self.mask_cnt, + self.mask_idx, + self.full_cnt, + self.full_idx, + num_n_blocks, + seqlen_q, + seqlen_k, + aux_tensors, + ).launch(grid=grid, block=[num_threads, 1, 1]) + + @cute.kernel + def kernel( + self, + mask_cnt: cute.Tensor, + mask_idx: cute.Tensor, + full_cnt: cute.Tensor, + full_idx: cute.Tensor, + num_n_blocks: Int32, + seqlen_q: Int32, + seqlen_k: Int32, + aux_tensors: Optional[list] = None, + ): + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + lane_id = cute.arch.lane_idx() + m_block, head_idx, batch_idx = cute.arch.block_idx() + + ssa = partial(scalar_to_ssa, dtype=Int32) + + seqlen = SeqlenInfoQK.create( + batch_idx, + seqlen_q, + seqlen_k, + mCuSeqlensQ=None, + mCuSeqlensK=None, + mSeqUsedQ=None, + mSeqUsedK=None, + ) + + @cute.struct + class SharedStorage: + reduction_buffer_smem: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int8, 2 * self.num_warps], 1024 + ] + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage, 16) + + reduction_buffer = storage.reduction_buffer_smem.get_tensor( + cute.make_layout((self.num_warps, 2)) + ) + + num_mask_blocks = Int32(0) + num_full_blocks = Int32(0) + + for n_block in cutlass.range(num_n_blocks, unroll_full=True): + m_base = m_block * self.tile_mn[0] + n_base = n_block * self.tile_mn[1] + + if const_expr(self.use_fast_sampling): + # Fast path: 5-point sampling (4 corners + center) + # Clamps OOB indices to nearest in bounds. 
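+                # Threads 0-3 sample the four tile corners and thread 4 samples the
+                # center; a single warp vote over these five points then feeds the
+                # block classification below. A mask whose boundary flips only
+                # strictly inside the tile can be misclassified by this check, so
+                # mask_mods must opt in via the @fast_sampling decorator (or
+                # use_fast_sampling=True).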
+ thread_result = Boolean(False) + thread_is_valid = Boolean(False) + q_idx = Int32(0) + kv_idx = Int32(0) + + if tidx == 0: + # Top-left corner (0, 0); always in bounds + q_idx = m_base + kv_idx = n_base + elif tidx == 1: + # Top-right corner + q_idx = m_base + kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1) + elif tidx == 2: + # Bottom-left corner + q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1) + kv_idx = n_base + elif tidx == 3: + # Bottom-right corner + q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1) + kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1) + elif tidx == 4: + # Center point + q_idx = m_base + (cutlass.min(seqlen_q - m_base, self.tile_mn[0])) // 2 + kv_idx = n_base + (cutlass.min(seqlen_k - n_base, self.tile_mn[1])) // 2 + else: + thread_is_valid = Boolean(False) + + # Check bounds and determine if this thread has a valid index pair + if tidx < 5 and q_idx < seqlen_q and kv_idx < seqlen_k: + thread_is_valid = Boolean(True) + q_idx_ssa = ssa(q_idx) + kv_idx_ssa = ssa(kv_idx) + thread_result = ssa_to_scalar( + self.mask_mod( + ssa(batch_idx), + ssa(head_idx), + q_idx_ssa, + kv_idx_ssa, + seqlen, + aux_tensors, + ) + ) + else: + thread_is_valid = Boolean(False) + + # Use vote_any_sync to see if any valid thread found unmasked or masked + # Only count results from threads that checked valid indices + has_unmasked = cute.arch.vote_any_sync(thread_result & thread_is_valid) + has_masked = cute.arch.vote_any_sync((Boolean(not thread_result)) & thread_is_valid) + + else: + # Full path: check all elements in the block + # Track if this thread's row has any masked or unmasked elements + thread_has_unmasked = Boolean(False) + thread_has_masked = Boolean(False) + thread_is_valid = Boolean(False) + + # Each thread handles 1 row + q_idx = m_base + tidx + kv_idx = Int32(0) + if tidx < self.tile_mn[0] and q_idx < seqlen_q: + thread_is_valid = Boolean(True) + q_idx_ssa = ssa(q_idx) + + # Loop over all columns in this row + for c in cutlass.range(self.tile_mn[1], unroll_full=True): + kv_idx = n_base + c + kv_idx_ssa = ssa(kv_idx) + + # Only check elements within valid sequence bounds + if kv_idx < seqlen_k: + # Direct scalar call + mask_val = ssa_to_scalar( + self.mask_mod( + ssa(batch_idx), + ssa(head_idx), + q_idx_ssa, + kv_idx_ssa, + seqlen, + aux_tensors, + ) + ) + + # Update tracking flags + if mask_val: + thread_has_unmasked = Boolean(True) + else: + thread_has_masked = Boolean(True) + + # Block-level reduction to combine results across all threads + # Only count votes from threads that checked valid indices + warp_has_unmasked_mask = cute.arch.vote_any_sync( + thread_has_unmasked & thread_is_valid + ) + warp_has_masked_mask = cute.arch.vote_any_sync(thread_has_masked & thread_is_valid) + + # lane 0 writes the ballot mask to shared memory + lane_id = tidx % 32 + if lane_id == 0: + # Store as Int8 + reduction_buffer[warp_idx, 0] = Int8(1) if warp_has_unmasked_mask else Int8(0) + reduction_buffer[warp_idx, 1] = Int8(1) if warp_has_masked_mask else Int8(0) + + cute.arch.sync_threads() + + # Thread 0 ORs all warp results together + has_unmasked = Boolean(False) + has_masked = Boolean(False) + if tidx == 0: + for w in cutlass.range(self.num_warps): + if reduction_buffer[w, 0]: + has_unmasked = Boolean(True) + if reduction_buffer[w, 1]: + has_masked = Boolean(True) + + # Only thread 0 updates the output arrays (common to both paths) + if tidx == 0: + # Block classification based on what we found: + # - If has_masked and 
has_unmasked: partial block (needs masking) + # - If only has_unmasked: full block (no masking needed) + # - If only has_masked: skip this block entirely + is_partial = Boolean(has_masked and has_unmasked) + is_full = Boolean(has_unmasked and (not has_masked)) + + if is_partial: + mask_idx[batch_idx, head_idx, m_block, num_mask_blocks] = n_block + num_mask_blocks += 1 + elif is_full and const_expr(self.compute_full_blocks): + full_idx[batch_idx, head_idx, m_block, num_full_blocks] = n_block + num_full_blocks += 1 + + # Only thread 0 writes back the counts + if tidx == 0: + mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks + if const_expr(self.compute_full_blocks): + full_cnt[batch_idx, head_idx, m_block] = num_full_blocks + + +def compute_block_sparsity( + tile_m, + tile_n, + batch_size, + num_heads, + seqlen_q, + seqlen_k, + mask_mod: Callable, + aux_tensors: Optional[list], # list[cute.Tensor] + device, + compute_full_blocks: bool = True, + use_fast_sampling: bool = False, +) -> Tuple[BlockSparseTensors, BlockSparseTensorsTorch]: + """ + Computes block sparsity for a given `mask_mod`. + + Args: + tile_m: The tile size for the m dimension. + tile_n: The tile size for the n dimension. + batch_size: The batch size. + num_heads: The number of heads. + seqlen_q: The sequence length for the query. + seqlen_k: The sequence length for the key. + mask_mod: The `mask_mod` callable to use. + aux_tensors: A list of auxiliary tensors. + device: The device to use. + compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed. + use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient. + + Returns: + A tuple of `BlockSparseTensors` and `BlockSparseTensorsTorch`. 
+ """ + # Check if mask_mod is marked as suitable for 5-point fast sampling + use_fast_sampling = getattr(mask_mod, "use_fast_sampling", use_fast_sampling) + + num_m_blocks = (seqlen_q + tile_m - 1) // tile_m + num_n_blocks = (seqlen_k + tile_n - 1) // tile_n + + mask_block_cnt = torch.zeros( + (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32 + ) + mask_block_idx = torch.zeros( + (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 + ) + full_block_cnt = ( + torch.zeros((batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32) + if compute_full_blocks + else None + ) + full_block_idx = ( + torch.zeros( + (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 + ) + if compute_full_blocks + else None + ) + + blocksparse_tensors_torch = BlockSparseTensorsTorch( + mask_block_cnt=mask_block_cnt, + mask_block_idx=mask_block_idx, + full_block_cnt=full_block_cnt, + full_block_idx=full_block_idx, + ) + + mask_mod_hash = hash_callable(mask_mod) + blocksparse_tensors = to_cute_block_sparse_tensors( + blocksparse_tensors_torch, enable_tvm_ffi=True + ) + + compile_key = ( + tile_m, + tile_n, + mask_mod_hash, + compute_full_blocks, + aux_tensors is not None, + use_fast_sampling, + ) + if compile_key not in compute_block_sparsity.compile_cache: + kernel = BlockSparsityKernel( + mask_mod, + tile_mn=(tile_m, tile_n), + compute_full_blocks=compute_full_blocks, + use_aux_tensors=aux_tensors is not None, + use_fast_sampling=use_fast_sampling, + ) + + compute_block_sparsity.compile_cache[compile_key] = cute.compile( + kernel, blocksparse_tensors, seqlen_q, seqlen_k, aux_tensors, options="--enable-tvm-ffi" + ) + + compute_block_sparsity.compile_cache[compile_key]( + blocksparse_tensors_torch, + seqlen_q, + seqlen_k, + aux_tensors, + ) + + return blocksparse_tensors, blocksparse_tensors_torch + + +compute_block_sparsity.compile_cache = {} diff --git a/flash_attn/cute/copy_utils.py b/flash_attn/cute/copy_utils.py new file mode 100644 index 00000000000..cfdcbdb80a0 --- /dev/null +++ b/flash_attn/cute/copy_utils.py @@ -0,0 +1,340 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. 
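+#
+# Low-level copy helpers shared by the CuTe flash-attention kernels: copy-atom
+# construction, 1D/2D tiled gmem<->smem copies, TMA partitioning helpers, and
+# inline-PTX wrappers for cp.async.bulk transfers and vectorized f32 atomics
+# and reductions.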
+ +import math +from typing import Optional, Type, Callable + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm +import cutlass.pipeline + + +@dsl_user_op +def cvt_copy( + atom: cute.CopyAtom, + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_fragment_like(src, dst.element_type, loc=loc, ip=ip) + src_cvt.store(src.load().to(dst.element_type)) + src = src_cvt + cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +@dsl_user_op +def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip) + cute.autovec_copy(src, dst, loc=loc, ip=ip) + return dst + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None +) -> cute.CopyAtom: + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@dsl_user_op +def make_tmem_copy( + tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None +) -> cute.CopyAtom: + num_dp, num_bits, num_rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom) + assert num_dp == 32 + assert num_bits == 32 + tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), stride=(32, 1)),) + layout_tv = cute.make_layout( + ((32, 4, num_wg), (num_rep, 32)), stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg)) + ) + return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +def tiled_copy_1d( + dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_layout(num_threads) + val_layout = cute.make_layout(num_copy_elems) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tiled_copy_2d( + dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width + copy_elems = num_copy_bits // dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + gmem_threads_per_row = major_mode_size // copy_elems + assert num_threads % gmem_threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 
0), + ) + val_layout = cute.make_layout((1, copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +@dsl_user_op +def atomic_add_fp32x4( + a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None +) -> None: + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x12F0000000000000) + llvm.inline_asm( + None, + [ + gmem_ptr_i64, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], + "{\n\t" + # ".reg .b128 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + ".reg .v4 .f32 abcd;\n\t" + # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" + "mov.f32 abcd.x, $1;\n\t" + "mov.f32 abcd.y, $2;\n\t" + "mov.f32 abcd.z, $3;\n\t" + "mov.f32 abcd.w, $4;\n\t" + "red.global.add.v4.f32 [$0], abcd;\n\t" + # "red.global.add.L2::cache_hint.v4.f32 [$0], abcd, 0x14F0000000000000;\n\t" + "}\n", + # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", + # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", + "l,f,f,f,f", + # "l,f,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def set_block_rank( + smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None +) -> Int32: + """Map the given smem pointer to the address at another CTA rank in the cluster.""" + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def store_shared_remote_fp32x4( + a: Float32, + b: Float32, + c: Float32, + d: Float32, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: Int32, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + llvm.inline_asm( + None, + [ + remote_smem_ptr_i32, + remote_mbar_ptr_i32, + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + Float32(d).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .v4 .f32 abcd;\n\t" + "mov.f32 abcd.x, $2;\n\t" + "mov.f32 abcd.y, $3;\n\t" + "mov.f32 abcd.z, $4;\n\t" + "mov.f32 abcd.w, $5;\n\t" + "st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\n\t" + "}\n", + "r,r,f,f,f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def cpasync_bulk_g2s( + gmem_ptr: cute.Pointer, + smem_ptr: cute.Pointer, + tma_bar_ptr: cute.Pointer, + size: int | Int32, + *, + loc=None, + ip=None, +): + gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + mbar_ptr_i32 = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value() + llvm.inline_asm( + None, + [gmem_ptr_i64, smem_ptr_i32, mbar_ptr_i32, Int32(size).ir_value()], + "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];", + "l,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def 
cpasync_reduce_bulk_add_f32( + smem_ptr: cute.Pointer, + gmem_ptr: cute.Pointer, + store_bytes: int | Int32, + *, + loc=None, + ip=None, +): + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST + llvm.inline_asm( + None, + [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], + "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", + "l,r,r", + # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()], + # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", + # "l,r,r,l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +def cpasync_bulk_get_copy_fn( + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + single_stage: bool = False, + **kwargs, +) -> Callable: + # src_is_smem = const_expr( + # isinstance(src_tensor.iterator, cute.Pointer) + # and src_tensor.memspace == cute.AddressSpace.smem + # ) + group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) + group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + src = cute.group_modes(src_tensor, 0, group_rank_src) + dst = cute.group_modes(dst_tensor, 0, group_rank_dst) + + def copy_bulk(src_idx, dst_idx, **new_kwargs): + size = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8) + cpasync_bulk_g2s( + src[None, src_idx].iterator, + dst[None, dst_idx].iterator, + size=size, + **new_kwargs, + **kwargs, + ) + + def copy_bulk_single_stage(**new_kwargs): + size = const_expr(cute.size(src.shape) * src.element_type.width // 8) + cpasync_bulk_g2s(src.iterator, dst.iterator, size=size, **new_kwargs, **kwargs) + + return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage + + +def tma_get_copy_fn( + atom: cute.CopyAtom, + cta_coord: cute.Coord, + cta_layout: cute.Layout, + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + filter_zeros: bool = False, + single_stage: bool = False, + **kwargs, +) -> Callable: + src_is_smem = const_expr( + isinstance(src_tensor.iterator, cute.Pointer) + and src_tensor.memspace == cute.AddressSpace.smem + ) + smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) + group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + s, g = cpasync.tma_partition( + atom, + cta_coord, + cta_layout, + cute.group_modes(smem_tensor, 0, group_rank_smem), + cute.group_modes(gmem_tensor, 0, group_rank_gmem), + ) + if const_expr(filter_zeros): + s = cute.filter_zeros(s) + g = cute.filter_zeros(g) + src, dst = (s, g) if src_is_smem else (g, s) + + def copy_tma(src_idx, dst_idx, **new_kwargs): + cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + + def copy_tma_single_stage(**new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs) + + return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g + + +def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): + def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): + copy( + src_idx=src_idx, + dst_idx=producer_state.index, + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), + 
**new_kwargs, + ) + + return copy_fn diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py new file mode 100644 index 00000000000..14723872b85 --- /dev/null +++ b/flash_attn/cute/cute_dsl_utils.py @@ -0,0 +1,144 @@ +# Copyright (c) 2025, Tri Dao. + +import os +import pathlib +from typing import Tuple +from functools import partial, lru_cache +from dataclasses import dataclass, fields + +import torch + +try: + from triton.tools.disasm import extract +except ImportError: + extract = None + +import cutlass +import cutlass.cute as cute +from cutlass.base_dsl.typing import JitArgument +from cutlass.cutlass_dsl import NumericMeta +from cutlass.cute.runtime import from_dlpack + +StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) + + +load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data +cute_compile_og = cute.compile + + +torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, +} + + +@lru_cache +def get_max_active_clusters(cluster_size): + return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) + + +@lru_cache +def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class ArgumentsBase(JitArgument): + def __c_pointers__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + c_ptrs = [] + for obj in non_constexpr_fields: + if hasattr(obj, "__c_pointers__"): + c_ptrs.extend(obj.__c_pointers__()) + return c_ptrs + + def __get_mlir_types__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + types, self._values_pos = [], [] + for obj in non_constexpr_fields: + if hasattr(obj, "__get_mlir_types__"): + obj_types = obj.__get_mlir_types__() + types.extend(obj_types) + self._values_pos.append(len(obj_types)) + else: + self._values_pos.append(0) + return types + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in 
zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +def load_cubin_module_data_patched(cubin_data, filepath): + pathlib.Path(filepath).write_bytes(cubin_data) + return load_cubin_module_data_og(cubin_data) + + +def cute_compile_patched(*args, **kwargs): + """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set.""" + cubin_path = os.getenv("CUTE_CUBIN_PATH", None) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( + load_cubin_module_data_patched, filepath=cubin_path + ) + output = cute_compile_og(*args, **kwargs) + if cubin_path is not None: + cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og + if extract is not None: + sass = extract(cubin_path, None) + pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) + return output + + +def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True): + """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1.""" + tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi) + if fully_dynamic: + return tensor.mark_layout_dynamic() + if leading_dim == -1: + leading_dim = t.ndim - 1 + return tensor.mark_layout_dynamic(leading_dim=leading_dim) + + +def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: + """Return tuple of bools indicating which dims have stride=0 (broadcast). + + This is useful for compile keys since CuTe's mark_layout_dynamic() keeps + stride=0 as static, meaning kernels compiled with different broadcast + patterns are not interchangeable. + """ + return tuple(s == 0 for s in tensor.stride()) diff --git a/flash_attn/cute/fast_math.py b/flash_attn/cute/fast_math.py new file mode 100644 index 00000000000..c56ea89e798 --- /dev/null +++ b/flash_attn/cute/fast_math.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025, Tri Dao. + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 + + +@cute.jit +def clz(x: Int32) -> Int32: + # for i in cutlass.range_constexpr(32): + # if (1 << (31 - i)) & x: + # return Int32(i) + # return Int32(32) + # Early exit is not supported yet + res = Int32(32) + done = False + for i in cutlass.range(32): + if ((1 << (31 - i)) & x) and not done: + res = Int32(i) + done = True + return res diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py new file mode 100644 index 00000000000..763e824e55b --- /dev/null +++ b/flash_attn/cute/flash_bwd.py @@ -0,0 +1,1262 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_bwd_sm80.hpp +# from Cutlass C++ to Cute-DSL. 
+import math +from types import SimpleNamespace +from typing import Type, Callable, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, warp +from cutlass import Float32, Int32 +import cutlass.utils as utils_basic + +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.tile_scheduler import ParamsBase, SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments + + +class FlashAttentionBackwardSm80: + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: int = 1, + m_block_size: int = 64, + n_block_size: int = 128, + num_stages_Q: int = 2, + num_stages_dO: int = 2, + num_threads: int = 256, + pack_gqa: bool = False, + is_causal: bool = False, + SdP_swapAB: bool = False, + dKV_swapAB: bool = False, + dQ_swapAB: bool = False, + AtomLayoutMSdP: int = 1, + AtomLayoutNdKV: int = 8, + AtomLayoutMdQ: int = 1, + V_in_regs: bool = False, + ): + """Initializes the configuration for a flash attention v2 kernel. + + All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension + should be a multiple of 8. + + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param n_block_size: n block size + :type n_block_size: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + """ + self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 32 + self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + # Can save registers (and hence be faster) if we don't have to check hdim predication + self.check_hdim_oob = head_dim != self.head_dim_padded + self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + self.qhead_per_kvhead = qhead_per_kvhead + self.m_block_size = m_block_size + self.n_block_size = n_block_size + self.num_threads = num_threads + self.pack_gqa = pack_gqa + self.is_causal = is_causal + self.num_stages_Q = num_stages_Q + self.num_stages_dO = num_stages_dO + self.SdP_swapAB = SdP_swapAB + self.dKV_swapAB = dKV_swapAB + self.dQ_swapAB = dQ_swapAB + self.AtomLayoutMSdP = AtomLayoutMSdP + self.AtomLayoutNdKV = AtomLayoutNdKV + self.AtomLayoutMdQ = AtomLayoutMdQ + num_mma_warps = self.num_threads // cute.arch.WARP_SIZE + self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB + self.V_in_regs = V_in_regs + self.share_QV_smem = V_in_regs + + @staticmethod + def can_implement( + dtype, head_dim, head_dim_v, m_block_size, n_block_size, num_stages_Q, num_stages_dO, + num_threads, is_causal, + V_in_regs=False + ) -> bool: + """Check if the kernel can be implemented with the given parameters. 
+ + :param dtype: data type + :type dtype: cutlass.Numeric + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param n_block_size: n block size + :type n_block_size: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + :type is_causal: bool + + :return: True if the kernel can be implemented, False otherwise + :rtype: bool + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if head_dim_v % 8 != 0: + return False + if n_block_size % 16 != 0: + return False + if num_threads % 32 != 0: + return False + # Check if block size setting is out of shared memory capacity + # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size + smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2 + smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2 + smem_usage_K = n_block_size * head_dim * 2 + smem_usage_V = n_block_size * head_dim_v * 2 + smem_usage_QV = (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V) + smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") + if smem_usage > smem_capacity: + return False + return True + + def _check_type( + self, + mQ_type: Type[cutlass.Numeric], + mK_type: Type[cutlass.Numeric], + mV_type: Type[cutlass.Numeric], + mdO_type: Type[cutlass.Numeric], + mLSE_type: Type[cutlass.Numeric], + mdPsum_type: Type[cutlass.Numeric], + mdQaccum_type: Type[cutlass.Numeric], + mdK_type: Type[cutlass.Numeric], + mdV_type: Type[cutlass.Numeric], + mCuSeqlensQ_type: Type[cutlass.Numeric] | None, + mCuSeqlensK_type: Type[cutlass.Numeric] | None, + mSeqUsedQ_type: Type[cutlass.Numeric] | None, + mSeqUsedK_type: Type[cutlass.Numeric] | None, + ): + if cutlass.const_expr(not (mQ_type == mK_type == mV_type == mdO_type)): + raise TypeError("All tensors must have the same data type") + if cutlass.const_expr(self.qhead_per_kvhead == 1): + if cutlass.const_expr(not (mdK_type == mdV_type == mQ_type)): + raise TypeError("mdK and mdV tensors must have the same data type as mQ") + else: + if cutlass.const_expr(not (mdK_type == mdV_type == cutlass.Float32)): + raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32") + if cutlass.const_expr(not mQ_type in [cutlass.Float16, cutlass.BFloat16]): + raise TypeError("Only Float16 or BFloat16 is supported") + if cutlass.const_expr(not mLSE_type in [cutlass.Float32]): + raise TypeError("LSE tensor must be Float32") + if cutlass.const_expr(not mdPsum_type in [cutlass.Float32]): + raise TypeError("dPsum tensor must be Float32") + if cutlass.const_expr(not mdQaccum_type in [cutlass.Float32]): + raise TypeError("dQaccum tensor must be Float32") + if cutlass.const_expr(mCuSeqlensQ_type not in [None, cutlass.Int32]): + raise TypeError("cuSeqlensQ tensor must be Int32") + if cutlass.const_expr(mCuSeqlensK_type not in [None, cutlass.Int32]): + raise TypeError("cuSeqlensK tensor must be Int32") + if cutlass.const_expr(mSeqUsedQ_type not in [None, cutlass.Int32]): + raise TypeError("SeqUsedQ tensor must be Int32") + if cutlass.const_expr(mSeqUsedK_type not in [None, cutlass.Int32]): + raise TypeError("SeqUsedK tensor must be Int32") + assert mQ_type == self.dtype + + def _setup_attributes(self): + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: Q/K/V + # 
/////////////////////////////////////////////////////////////////////////////// + sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_padded) + self.sQ_layout = cute.tile_to_shape( + sQ_layout_atom, (self.m_block_size, self.head_dim_padded, self.num_stages_Q), (0, 1, 2), + ) + sK_layout_atom = sQ_layout_atom + self.sK_layout = cute.tile_to_shape( + sK_layout_atom, (self.n_block_size, self.head_dim_padded), (0, 1), + ) + sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_v_padded) + self.sV_layout = cute.tile_to_shape( + sV_layout_atom, (self.n_block_size, self.head_dim_v_padded), (0, 1), + ) + sdO_layout_atom = sV_layout_atom + self.sdO_layout = cute.tile_to_shape( + sdO_layout_atom, (self.m_block_size, self.head_dim_v_padded, self.num_stages_dO), (0, 1, 2), + ) + # TODO: do we set swizzle to be 3 here explicitly? + sPdS_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.n_block_size) + self.sPdS_layout = cute.tile_to_shape( + sPdS_layout_atom, (self.m_block_size, self.n_block_size), (0, 1), + ) + # We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds, + # it's still a valid smem address. + self.sLSE_layout = cute.make_layout( + (self.m_block_size, self.num_stages_Q), + stride=(1, cute.round_up(self.m_block_size, 64)), + ) + sLSEMma_layout = cute.make_layout( + (self.m_block_size, self.n_block_size, self.num_stages_Q), + stride=(1, 0, cute.round_up(self.m_block_size, 64)), + ) + sLSEMma_layout_transposed = cute.make_layout( + (self.n_block_size, self.m_block_size, self.num_stages_Q), + stride=(0, 1, cute.round_up(self.m_block_size, 64)), + ) + self.sLSEMma_layout = sLSEMma_layout if not self.SdP_swapAB else sLSEMma_layout_transposed + + # /////////////////////////////////////////////////////////////////////////////// + # GMEM Tiled copy: + # /////////////////////////////////////////////////////////////////////////////// + # Thread layouts for copies + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype.width + # atom_async_copy: async copy atom for QKV load + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + # atom_universal_copy: universal copy atom for O store + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits, + ) + # tQK_layout: thread layout for QK load + tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems + assert self.num_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1" + tQK_layout = cute.make_ordered_layout( + (self.num_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), + ) + # Do we need to check if we overshot kBlockM when we load Q? + self.is_even_m_smem_q = self.m_block_size % tQK_layout.shape[0] == 0 + # Do we need to check if we overshot kBlockN when we load K? + self.is_even_n_smem_k = self.n_block_size % tQK_layout.shape[0] == 0 + tVdO_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems + assert self.num_threads % tVdO_shape_dim_1 == 0, "num_threads must be divisible by tVdO_shape_dim_1" + tVdO_layout = cute.make_ordered_layout( + (self.num_threads // tVdO_shape_dim_1, tVdO_shape_dim_1), order=(1, 0), + ) + # Do we need to check if we overshot kBlockN when we load V? 
+ self.is_even_n_smem_v = self.n_block_size % tVdO_layout.shape[0] == 0 + self.is_even_m_smem_do = self.m_block_size % tVdO_layout.shape[0] == 0 + + # Value layouts for copies + vQKVdO_layout = cute.make_layout((1, async_copy_elems)) + + # gmem_tiled_copy_QK: tiled copy for QK load + self.gmem_tiled_copy_QK = cute.make_tiled_copy_tv(atom_async_copy, tQK_layout, vQKVdO_layout) + self.gmem_tiled_copy_VdO = cute.make_tiled_copy_tv(atom_async_copy, tVdO_layout, vQKVdO_layout) + self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv(atom_universal_copy, tQK_layout, vQKVdO_layout) + self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv(atom_universal_copy, tVdO_layout, vQKVdO_layout) + async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width + + # I think we wouldn't require this with smarter padding + if cutlass.const_expr(not self.varlen_q): + async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width + atom_async_copy_accum = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + cutlass.Float32, + num_bits_per_copy=universal_copy_bits, + ) + else: + async_copy_elems_accum = 1 + atom_async_copy_accum = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float32, + num_bits_per_copy=cutlass.Float32.width, + ) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_accum, + cute.make_layout(self.num_threads), + cute.make_layout(async_copy_elems_accum), + ) + self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=cutlass.Float32.width + ), + cute.make_layout(self.num_threads), + cute.make_layout(1) + ) + if cutlass.const_expr(self.qhead_per_kvhead > 1): + self.gmem_tiled_copy_dK = self.gmem_tiled_copy_dQaccum + self.gmem_tiled_copy_dV = self.gmem_tiled_copy_dQaccum + + def _get_tiled_mma(self): + num_mma_warps = self.num_threads // 32 + AtomLayoutSdP = (self.AtomLayoutMSdP, num_mma_warps // self.AtomLayoutMSdP, 1) if cutlass.const_expr(not self.SdP_swapAB) else (num_mma_warps // self.AtomLayoutMSdP, self.AtomLayoutMSdP, 1) + tiled_mma_sdp = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), + AtomLayoutSdP, + permutation_mnk=(AtomLayoutSdP[0] * 16, AtomLayoutSdP[1] * 16, 16), + ) + AtomLayoutdKV = (self.AtomLayoutNdKV, num_mma_warps // self.AtomLayoutNdKV, 1) if cutlass.const_expr(not self.dKV_swapAB) else (num_mma_warps // self.AtomLayoutNdKV, self.AtomLayoutNdKV, 1) + tiled_mma_dkv = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), + AtomLayoutdKV, + permutation_mnk=(AtomLayoutdKV[0] * 16, AtomLayoutdKV[1] * 16, 16), + ) + AtomLayoutdQ = (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if cutlass.const_expr(not self.dQ_swapAB) else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) + tiled_mma_dq = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), + AtomLayoutdQ, + permutation_mnk=(AtomLayoutdQ[0] * 16, AtomLayoutdQ[1] * 16, 16), + ) + return tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq + + def _get_shared_storage_cls(self): + sQ_struct, sK_struct, sV_struct, sdO_struct = [ + cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024] + for layout in (self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout) + ] + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] + sLSE_struct, 
sdPsum_struct = [ + cute.struct.Align[cute.struct.MemRange[cutlass.Float32, cute.cosize(layout)], 128] + for layout in (self.sLSE_layout, self.sLSE_layout) + ] + sP_struct, sdS_struct = [ + cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 128] + for layout in (self.sPdS_layout, self.sPdS_layout) + ] + + @cute.struct + class SharedStorageSeparateQV: + sK: sK_struct + sV: sV_struct + sQ: sQ_struct + sdO: sdO_struct + sLSE: sLSE_struct + sdPsum: sdPsum_struct + sP: sP_struct + sdS: sdS_struct + # TODO: the case where there's no sP + + @cute.struct + class SharedStorageSharedQV: + sK: sK_struct + sV: sV_struct + sQ: sQV_struct + sdO: sdO_struct + sLSE: sLSE_struct + sdPsum: sdPsum_struct + sP: sP_struct + sdS: sdS_struct + + return SharedStorageSeparateQV if cutlass.const_expr(not self.share_QV_smem) else SharedStorageSharedQV + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdQaccum: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + softmax_scale: cutlass.Float32, + stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + mdQ_semaphore: Optional[cute.Tensor] = None, + ): + assert mdQ_semaphore is None, "semaphore not supported yet" + # Get the data type and check if it is fp16 or bf16 + self._check_type(*(t.element_type if t is not None else None + for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK))) + # Assume all strides are divisible by 128 bits except the last stride + # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints) + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) if not isinstance(s, int) or s != 0 else s for s in t.stride[:-1]), t.stride[-1]) + mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)] + self.varlen_q = (mCuSeqlensQ is not None) + self._setup_attributes() + SharedStorage = self._get_shared_storage_cls() + tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq = self._get_tiled_mma() + + num_head = mQ.shape[1] if cutlass.const_expr(mCuSeqlensQ is not None) else mQ.shape[2] + + if cutlass.const_expr(mCuSeqlensK is not None): + TileScheduler = SingleTileVarlenScheduler + num_batch = mCuSeqlensK.shape[0] - 1 + else: + TileScheduler = SingleTileScheduler + num_batch = mK.shape[0] + + # Uses seqlen k, etc. 
since main bwd kernel's blocks are over n + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(mK.shape[1], self.n_block_size), + num_head=num_head, + num_batch=num_batch, + num_splits=1, + seqlen_k=0, + headdim=mK.shape[2], + headdim_v=mV.shape[2], + total_q=mK.shape[0], + tile_shape_mn=(self.n_block_size, self.m_block_size), + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, + mCuSeqlensQ=mCuSeqlensK, + mSeqUsedQ=mSeqUsedK, + ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + + softmax_scale_log2 = softmax_scale * math.log2(math.e) + self.kernel( + mQ, + mK, + mV, + mdO, + mLSE, + mdPsum, + mdQaccum, + mdK, + mdV, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + softmax_scale, + softmax_scale_log2, + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sdO_layout, + self.sPdS_layout, + self.sLSE_layout, + self.sLSEMma_layout, + self.gmem_tiled_copy_QK, + self.gmem_tiled_copy_VdO, + self.gmem_tiled_copy_dK, + self.gmem_tiled_copy_dV, + self.gmem_tiled_copy_LSE, + self.gmem_tiled_copy_dQaccum, + tiled_mma_sdp, + tiled_mma_dkv, + tiled_mma_dq, + SharedStorage, + tile_sched_params, + TileScheduler, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdQaccum: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + softmax_scale: cutlass.Float32, + softmax_scale_log2: cutlass.Float32, + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sdO_layout: cute.ComposedLayout, + sPdS_layout: cute.ComposedLayout, + sLSE_layout: cute.Layout, + sLSEMma_layout: cute.Layout, + gmem_tiled_copy_QK: cute.TiledCopy, + gmem_tiled_copy_VdO: cute.TiledCopy, + gmem_tiled_copy_dK: cute.TiledCopy, + gmem_tiled_copy_dV: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + gmem_tiled_copy_dQaccum: cute.TiledCopy, + tiled_mma_sdp: cute.TiledMma, + tiled_mma_dkv: cute.TiledMma, + tiled_mma_dq: cute.TiledMma, + SharedStorage: cutlass.Constexpr, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + ): + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + + tile_scheduler = TileScheduler.create(tile_sched_params) + work_tile = tile_scheduler.initial_work_tile_info() + + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + + if work_tile.is_valid_tile: + seqlen = SeqlenInfoQK.create(batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK) + + m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size) + m_block_min = 0 + if cutlass.const_expr(self.is_causal): + m_block_min = max( + (n_block * self.n_block_size + seqlen.seqlen_q - seqlen.seqlen_k) // self.m_block_size, + m_block_min, + ) + # TODO: return early if m_block_max == 0 + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. 
+ # /////////////////////////////////////////////////////////////////////////////// + blkQ_shape = (self.m_block_size, self.head_dim_padded) + blkK_shape = (self.n_block_size, self.head_dim_padded) + blkV_shape = (self.n_block_size, self.head_dim_v_padded) + blkdO_shape = (self.m_block_size, self.head_dim_v_padded) + + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[batch_idx, None, head_idx, None] + mLSE_cur = mLSE[batch_idx, head_idx, None] + mdO_cur = mdO[batch_idx, None, head_idx, None] + mdPsum_cur = mdPsum[batch_idx, head_idx, None] + mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] + else: + padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size + mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None]) + mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None]) + mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None]) + mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None]) + mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None]) + head_idx_kv = head_idx // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else head_idx + + if cutlass.const_expr(not seqlen.has_cu_seqlens_k): + mK_cur, mV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mK, mV)] + else: + mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mK, mV)] + + # (m_block_size, head_dim, m_block) + gQ = cute.local_tile(mQ_cur, blkQ_shape, (None, 0)) + # (n_block_size, head_dim) + gK = cute.local_tile(mK_cur, blkK_shape, (n_block, 0)) + # (n_block_size, head_dim_v) + gV = cute.local_tile(mV_cur, blkV_shape, (n_block, 0)) + # (m_block_size, head_dim_v, m_block) + gdO = cute.local_tile(mdO_cur, blkdO_shape, (None, 0)) + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,)) + gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,)) + gdQaccum = cute.local_tile(mdQaccum_cur, (self.m_block_size * self.head_dim_padded,), (None,)) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ = storage.sQ.get_tensor(sQ_layout) + sK = storage.sK.get_tensor(sK_layout) + if cutlass.const_expr(not self.share_QV_smem): + sV = storage.sV.get_tensor(sV_layout) + else: + sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) + sdO = storage.sdO.get_tensor(sdO_layout) + sP = storage.sP.get_tensor(sPdS_layout) + sdS = storage.sdS.get_tensor(sPdS_layout) + sLSE = storage.sLSE.get_tensor(sLSE_layout) + sdPsum = storage.sdPsum.get_tensor(sLSE_layout) + sLSEMma = storage.sLSE.get_tensor(sLSEMma_layout) + sdPsumMma = storage.sdPsum.get_tensor(sLSEMma_layout) + + # Transpose view of tensors for tiled mma + sQt, sdOt, sKt, sPt, sdSt = [utils.transpose_view(t) for t in (sQ, sdO, sK, sP, sdS)] + + gmem_thr_copy_QK = gmem_tiled_copy_QK.get_slice(tidx) + gmem_thr_copy_VdO = gmem_tiled_copy_VdO.get_slice(tidx) + gmem_thr_copy_lse = gmem_tiled_copy_LSE.get_slice(tidx) + gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) + # (CPY_Atom, CPY_M, CPY_K, m_block) + tQgQ = gmem_thr_copy_QK.partition_S(gQ) + tQsQ = gmem_thr_copy_QK.partition_D(sQ) + # (CPY_Atom, CPY_N, CPY_K) + tKgK = gmem_thr_copy_QK.partition_S(gK) + tKsK = gmem_thr_copy_QK.partition_D(sK) + # (CPY_Atom, 
CPY_N, CPY_K) + tVgV = gmem_thr_copy_VdO.partition_S(gV) + tVsV = gmem_thr_copy_VdO.partition_D(sV) + # (CPY_Atom, CPY_M, CPY_K, m_block) + tdOgdO = gmem_thr_copy_VdO.partition_S(gdO) + tdOsdO = gmem_thr_copy_VdO.partition_D(sdO) + tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE) + tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE) + tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum) + tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum) + tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma_sdp = tiled_mma_sdp.get_slice(tidx) + thr_mma_dkv = tiled_mma_dkv.get_slice(tidx) + thr_mma_dq = tiled_mma_dq.get_slice(tidx) + acc_shape_dK = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_padded)) + acc_shape_dV = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_v_padded)) + acc_dK = cute.make_fragment(acc_shape_dK, cutlass.Float32) + acc_dV = cute.make_fragment(acc_shape_dV, cutlass.Float32) + acc_dK.fill(0.0) + acc_dV.fill(0.0) + + tSrQ = utils.mma_make_fragment_A(sQ[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB) + tSrK = utils.mma_make_fragment_B(sK, thr_mma_sdp, swapAB=self.SdP_swapAB) + tdPrdO = utils.mma_make_fragment_A(sdO[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB) + tdPrV = utils.mma_make_fragment_B(sV, thr_mma_sdp, swapAB=self.SdP_swapAB) + tdVrP = utils.mma_make_fragment_A(sPt, thr_mma_dkv, swapAB=self.dKV_swapAB) + tdVrdO = utils.mma_make_fragment_B(sdOt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB) + tdKrdS = utils.mma_make_fragment_A(sdSt, thr_mma_dkv, swapAB=self.dKV_swapAB) + tdKrQ = utils.mma_make_fragment_B(sQt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB) + tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB) + tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB) + + LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None) + tSsLSEMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sLSEMma))[LSEslice] + tSsdPsumMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice] + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype, + ) + smem_copy_atom_transposed = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype, + ) + smem_thr_copy_QdO = utils.make_tiled_copy_A( + smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB + ).get_slice(tidx) + smem_thr_copy_KV = utils.make_tiled_copy_B( + smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB + ).get_slice(tidx) + # TODO: should this be smem_copy_atom_transposed? 
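+        # Note: the *t operands below (sPt, sdSt, sQt, sdOt, sKt) are transposed views of the
+        # same smem buffers, so they are read through the transposed ldmatrix atom, while sdS
+        # for the dQ gemm is read in its natural layout with the non-transposed atom.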
+ smem_thr_copy_PdSt = utils.make_tiled_copy_A( + smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB + ).get_slice(tidx) + smem_thr_copy_QdOt = utils.make_tiled_copy_B( + smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB + ).get_slice(tidx) + smem_thr_copy_dS = utils.make_tiled_copy_A( + smem_copy_atom, tiled_mma_dq, swapAB=self.dQ_swapAB + ).get_slice(tidx) + smem_thr_copy_Kt = utils.make_tiled_copy_B( + smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB + ).get_slice(tidx) + # TODO: what's the number of bits? What if SdP_swapAB + r2s_thr_copy_PdS = cute.make_tiled_copy_C( + cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width + ), + tiled_mma_sdp, + ).get_slice(tidx) + + tSsQ = smem_thr_copy_QdO.partition_S(sQ) + tdPsdO = smem_thr_copy_QdO.partition_S(sdO) + tSsK = smem_thr_copy_KV.partition_S(sK) + tdPsV = smem_thr_copy_KV.partition_S(sV) + tdVsPt = smem_thr_copy_PdSt.partition_S(sPt) + tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt) + tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt) + tdKsQt = smem_thr_copy_QdOt.partition_S(sQt) + tdQsdS = smem_thr_copy_dS.partition_S(sdS) + tdQsKt = smem_thr_copy_Kt.partition_S(sKt) + tPsP = r2s_thr_copy_PdS.partition_D(sP) + tdSsdS = r2s_thr_copy_PdS.partition_D(sdS) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for KV + cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tQcQ = gmem_thr_copy_QK.partition_S(cQ) + t0QcQ = gmem_thr_copy_QK.get_slice(0).partition_S(cQ) + if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded): + tdOcdO = tQcQ + t0dOcdO = t0QcQ + else: + cdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tdOcdO = gmem_thr_copy_VdO.partition_S(cdO) + t0dOcdO = gmem_thr_copy_VdO.get_slice(0).partition_S(cdO) + cLSE = cute.make_identity_tensor((self.m_block_size,)) + tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE) + + # Allocate predicate tensors for m and n, here we only allocate the tile of k, and + # use "if" on the mn dimension. + # This is to reduce register pressure and gets 2-3% performance gain. 
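+        # In other words: predicate_k() below materializes only the head-dim (k) predicate per
+        # thread, while the row (m/n) bound is re-derived as a cheap scalar comparison inside
+        # each copy loop, so no per-row predicate fragment has to live in registers.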
+ + d_head = mQ.shape[cute.rank(mQ) - 1] + d_head_v = mdO.shape[cute.rank(mdO) - 1] + + tQpQ = utils.predicate_k(tQcQ, limit=d_head) + if cutlass.const_expr(self.same_hdim_kv): + tdOpdO = tQpQ + else: + tdOpdO = utils.predicate_k(tdOcdO, limit=d_head_v) + + # group parameters for compute_one_m_block + mma_params = SimpleNamespace( + thr_mma_sdp=thr_mma_sdp, thr_mma_dkv=thr_mma_dkv, thr_mma_dq=thr_mma_dq, + tSrQ=tSrQ, tSrK=tSrK, tdPrdO=tdPrdO, tdPrV=tdPrV, + tdVrP=tdVrP, tdVrdO=tdVrdO, tdKrdS=tdKrdS, tdKrQ=tdKrQ, + tdQrdS=tdQrdS, tdQrK=tdQrK, + acc_dK=acc_dK, acc_dV=acc_dV, + ) + smem_copy_params = SimpleNamespace( + smem_thr_copy_QdO=smem_thr_copy_QdO, + smem_thr_copy_KV=smem_thr_copy_KV, + smem_thr_copy_PdSt=smem_thr_copy_PdSt, + smem_thr_copy_QdOt=smem_thr_copy_QdOt, + smem_thr_copy_dS=smem_thr_copy_dS, + smem_thr_copy_Kt=smem_thr_copy_Kt, + r2s_thr_copy_PdS=r2s_thr_copy_PdS, + tSsQ=tSsQ, tSsK=tSsK, tdPsdO=tdPsdO, tdPsV=tdPsV, + tSsLSEMma=tSsLSEMma, tSsdPsumMma=tSsdPsumMma, + tPsP=tPsP, tdSsdS=tdSsdS, + tdVsPt=tdVsPt, tdVsdOt=tdVsdOt, tdKsdSt=tdKsdSt, tdKsQt=tdKsQt, + tdQsdS=tdQsdS, tdQsKt=tdQsKt, + ) + gmem_copy_params = SimpleNamespace( + gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum + ) + load_Q_LSE = partial( + self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE, + tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ, + tLSEgLSE, tLSEsLSE, tLSEcLSE, seqlen=seqlen.seqlen_q + ) + load_dO_dPsum = partial( + self.load_dO_dPsum, gmem_tiled_copy_VdO, gmem_tiled_copy_LSE, + tdOgdO, tdOsdO, tdOcdO, t0dOcdO, tdOpdO, + tLSEgdPsum, tLSEsdPsum, tLSEcLSE, seqlen=seqlen.seqlen_q + ) + compute_one_m_block = partial( + self.compute_one_m_block, mma_params=mma_params, + smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params, + load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum, + m_block_max=m_block_max, + softmax_scale_log2=softmax_scale_log2, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Start async loads of the last mn-tile, where we take care of the mn residue + self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k, + headdim=d_head_v) + if cutlass.const_expr(self.V_in_regs): + cute.arch.cp_async_commit_group() + self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k, + headdim=d_head) + cute.arch.cp_async_commit_group() + + if cutlass.const_expr(self.V_in_regs): + cute.arch.cp_async_wait_group(1) + cute.arch.barrier() + tdPrV_copy_view = smem_thr_copy_KV.retile(tdPrV) + cute.copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view) + # Sync to avoid loading Q to smem_q, which overlaps with smem_v + cute.arch.barrier() + + m_block = m_block_min + assert self.num_stages_Q >= self.num_stages_dO + for stage in cutlass.range_constexpr(self.num_stages_Q): + if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1): + if stage == 0 or m_block + stage < m_block_max: + load_Q_LSE(m_block + stage, smem_pipe_write_q=stage) + cute.arch.cp_async_commit_group() + if cutlass.const_expr(stage < self.num_stages_dO): + if stage == 0 or m_block + stage < m_block_max: + load_dO_dPsum(m_block + stage, smem_pipe_write_q=stage) + cute.arch.cp_async_commit_group() + + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # /////////////////////////////////////////////////////////////////////////////// + # Start processing of the first n-block. 
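+        # Before the loop we build the masking closure and the smem pipeline indices. Sketch of
+        # the bookkeeping, assuming num_stages_Q = 2 and num_stages_dO = 1 (one possible
+        # configuration): the prologue above filled stage 0, the next Q stage is prefetched from
+        # inside compute_one_m_block, reads start at stage 0, writes start at stage
+        # num_stages_Q - 1, and advance_pipeline() wraps each index once it reaches the stage count.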
+ mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k) + mask_fn = partial( + mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp, + mask_seqlen=True, mask_causal=self.is_causal + ) + smem_pipe_read_q = cutlass.Int32(0) + smem_pipe_read_do = cutlass.Int32(0) + smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1) + smem_pipe_write_do = cutlass.Int32(0) + for m_tile in cutlass.range(m_block_min, m_block_max, unroll=1): + compute_one_m_block( + m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do, + mask_fn=mask_fn, + ) + smem_pipe_read_q = self.advance_pipeline(smem_pipe_read_q, self.num_stages_Q) + smem_pipe_read_do = self.advance_pipeline(smem_pipe_read_do, self.num_stages_dO) + smem_pipe_write_q = self.advance_pipeline(smem_pipe_write_q, self.num_stages_Q) + smem_pipe_write_do = self.advance_pipeline(smem_pipe_write_do, self.num_stages_dO) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + # If GQA, we scale dK in the postprocessing kernel instead + if cutlass.const_expr(self.qhead_per_kvhead == 1): + acc_dK.store(acc_dK.load() * softmax_scale) + # reuse sK and sV data iterator + sdK = cute.make_tensor(sK.iterator, sK_layout) + sdV = cute.make_tensor(sV.iterator, sV_layout) + self.epilogue( + acc_dK, acc_dV, mdK, mdV, sdK, sdV, + gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv, + tidx, n_block, head_idx, batch_idx, seqlen, d_head, d_head_v + ) + + @cute.jit + def compute_one_m_block( + self, + m_block: cutlass.Int32, + smem_pipe_read_q: cutlass.Int32, + smem_pipe_read_do: cutlass.Int32, + smem_pipe_write_q: cutlass.Int32, + smem_pipe_write_do: cutlass.Int32, + mma_params: SimpleNamespace, + smem_copy_params: SimpleNamespace, + gmem_copy_params: SimpleNamespace, + load_Q_LSE: Callable, + load_dO_dPsum: Callable, + m_block_max: cutlass.Int32, + softmax_scale_log2: cutlass.Float32, + mask_fn: Optional[Callable] = None, + ): + def load_Q_next(): + m_block_next = m_block + (self.num_stages_Q - 1 if cutlass.const_expr(self.num_stages_Q > 1) else 1) + if m_block_next < m_block_max: + load_Q_LSE(m_block_next, smem_pipe_write_q) + cute.arch.cp_async_commit_group() + + def load_dO_next(): + if m_block + self.num_stages_dO < m_block_max: + load_dO_dPsum(m_block + self.num_stages_dO, smem_pipe_write_do) + cute.arch.cp_async_commit_group() + + # MMA S + acc_shape_SdP = mma_params.thr_mma_sdp.partition_shape_C( + (self.m_block_size, self.n_block_size) if cutlass.const_expr(not self.SdP_swapAB) else (self.n_block_size, self.m_block_size) + ) + acc_S = cute.make_fragment(acc_shape_SdP, cutlass.Float32) + acc_S.fill(0.0) + cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_Q > 1) else 0) + cute.arch.barrier() + sm80_utils.gemm( + mma_params.thr_mma_sdp, acc_S, mma_params.tSrQ, mma_params.tSrK, + smem_copy_params.tSsQ[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], + smem_copy_params.tSsK, + smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV, + swap_AB=self.SdP_swapAB, + ) + tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0]) + cute.autovec_copy( + smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE + ) + if cutlass.const_expr(mask_fn is not None): + mask_fn(acc_S, m_block=m_block) + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) + bidx 
= 0 + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE) + assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE) + for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True): + acc_S_mn[r, None].store(utils.exp2f(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r])) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) + + # MMA dP + acc_dP = cute.make_fragment(acc_shape_SdP, cutlass.Float32) + acc_dP.fill(0.0) + cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_dO > 1) else 0) + cute.arch.barrier() + sm80_utils.gemm( + mma_params.thr_mma_sdp, acc_dP, mma_params.tdPrdO, mma_params.tdPrV, + smem_copy_params.tdPsdO[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], + smem_copy_params.tdPsV, + smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV, + hook_fn=load_Q_next if cutlass.const_expr(self.num_stages_Q > 1) else None, + swap_AB=self.SdP_swapAB, + ) + tLSErdPsum = cute.make_fragment_like(smem_copy_params.tSsdPsumMma[None, 0]) + cute.autovec_copy( + smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], tLSErdPsum + ) + acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) + assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum) + for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True): + acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r])) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) + rP = cute.make_fragment_like(acc_S, self.dtype) + rP.store(acc_S.load().to(self.dtype)) + if cutlass.const_expr(not self.Mma_dKV_is_RS): + tPrP = smem_copy_params.r2s_thr_copy_PdS.retile(rP) # ((Atom,AtomNum), MMA_N, MMA_N) + cute.copy(smem_copy_params.r2s_thr_copy_PdS, tPrP, smem_copy_params.tPsP) + rdS = cute.make_fragment_like(acc_dP, self.dtype) + rdS.store(acc_dP.load().to(self.dtype)) + if cutlass.const_expr(not self.Mma_dKV_is_RS): + cute.arch.barrier() # Make sure P is written + # For hdim 64, It's faster to write to smem_dS first before the dV gemm + if cutlass.const_expr(not self.Mma_dKV_is_RS): + tdSrdS = smem_copy_params.r2s_thr_copy_PdS.retile(rdS) + cute.copy(smem_copy_params.r2s_thr_copy_PdS, tdSrdS, smem_copy_params.tdSsdS) + if cutlass.const_expr(self.Mma_dKV_is_RS): + tdVrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) + else: + tdVrP = mma_params.tdVrP + + # MMA dK + sm80_utils.gemm( + mma_params.thr_mma_dkv, mma_params.acc_dV, tdVrP, mma_params.tdVrdO, + smem_copy_params.tdVsPt, + smem_copy_params.tdVsdOt[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], + smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt, + A_in_regs=self.Mma_dKV_is_RS, + swap_AB=self.dKV_swapAB, + ) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(mma_params.acc_dV) + cute.arch.barrier() # Make sure dS is written + + # MMA dQ + def dQ_mma(hook_fn): + acc_shape_dQ = mma_params.thr_mma_dq.partition_shape_C( + (self.m_block_size, self.head_dim_padded) if cutlass.const_expr(not self.dQ_swapAB) else 
(self.head_dim_padded, self.m_block_size) + ) + acc_dQ = cute.make_fragment(acc_shape_dQ, cutlass.Float32) + acc_dQ.fill(0.0) + sm80_utils.gemm( + mma_params.thr_mma_dq, acc_dQ, mma_params.tdQrdS, mma_params.tdQrK, + smem_copy_params.tdQsdS, smem_copy_params.tdQsKt, + smem_copy_params.smem_thr_copy_dS, smem_copy_params.smem_thr_copy_Kt, + swap_AB=self.dQ_swapAB, + hook_fn=hook_fn + ) + # ((1, 1), num_elements) + acc_dQ_atomic = gmem_copy_params.gmem_thr_copy_dQaccum.retile(acc_dQ) + tdQgdQaccum_atomic = gmem_copy_params.tdQgdQaccum[None, None, m_block] + assert cute.size(acc_dQ_atomic) == cute.size(tdQgdQaccum_atomic) + for i in cutlass.range(cute.size(acc_dQ_atomic), unroll_full=True): + utils.atomic_add_fp32(acc_dQ_atomic[i], utils.elem_pointer(tdQgdQaccum_atomic, i)) + # utils.atomic_add_fp32(acc_dQ[i], tdQgdQaccum_atomic.iterator + i * tdQgdQaccum_atomic.stride[1]) + # if cute.arch.thread_idx()[0] == 64 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dQ) + + # If num_stages_Q == 1, we want to do Mma_dK first so we can start loading Q for the next iteration + if cutlass.const_expr(self.num_stages_Q > 1): + dQ_mma(load_dO_next) + + # MMA dK + if cutlass.const_expr(self.Mma_dKV_is_RS): + tdKrdS = cute.make_tensor(rdS.iterator, utils.convert_layout_acc_frgA(rdS.layout)) + else: + tdKrdS = mma_params.tdKrdS + sm80_utils.gemm( + mma_params.thr_mma_dkv, mma_params.acc_dK, tdKrdS, mma_params.tdKrQ, + smem_copy_params.tdKsdSt, + smem_copy_params.tdKsQt[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], + smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt, + A_in_regs=self.Mma_dKV_is_RS, + swap_AB=self.dKV_swapAB, + hook_fn=load_dO_next if cutlass.const_expr(self.num_stages_Q == 1) else None, + ) + # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(mma_params.acc_dK) + if cutlass.const_expr(self.num_stages_Q == 1): + cute.arch.barrier() + dQ_mma(load_Q_next) + + @cute.jit + def epilogue( + self, + acc_dK: cute.Tensor, + acc_dV: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + sdK: cute.Tensor, + sdV: cute.Tensor, + gmem_tiled_copy_dK: cute.TiledCopy, + gmem_tiled_copy_dV: cute.TiledCopy, + tiled_mma: cute.TiledMma, + tidx: cutlass.Int32, + n_block: cutlass.Int32, + num_head: cutlass.Int32, + batch_size: cutlass.Int32, + seqlen: SeqlenInfoQK, + d_head: cutlass.Int32, + d_head_v: cutlass.Int32 + ): + rdV = cute.make_fragment_like(acc_dV, self.dtype) + rdV.store(acc_dV.load().to(self.dtype)) + rdK = cute.make_fragment_like(acc_dK, self.dtype) + rdK.store(acc_dK.load().to(self.dtype)) + gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx) + gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx) + + batch_idx = batch_size + head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head + + if cutlass.const_expr(self.qhead_per_kvhead == 1): + # Make sure all threads have finished reading K and V, otherwise we get racy dQ + # because smem_q could be changed. 
+ cute.arch.barrier() + # smem copy atom for dKV + smem_copy_atom_dKV = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width + ) + smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma).get_slice(tidx) + taccdVrdV = smem_thr_copy_dKV.retile(rdV) + taccdKrdK = smem_thr_copy_dKV.retile(rdK) + taccdVsdV = smem_thr_copy_dKV.partition_D(sdV) + taccdKsdK = smem_thr_copy_dKV.partition_D(sdK) + # copy acc O from rmem to smem with the smem copy atom + cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) + cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) + + + if cutlass.const_expr(not seqlen.has_cu_seqlens_k): + mdK_cur, mdV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mdK, mdV)] + else: + mdK_cur, mdV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mdK, mdV)] + + blkdK_shape = (self.n_block_size, self.head_dim_padded) + blkdV_shape = (self.n_block_size, self.head_dim_v_padded) + gdK = cute.local_tile(mdK_cur, blkdK_shape, (n_block, 0)) + gdV = cute.local_tile(mdV_cur, blkdV_shape, (n_block, 0)) + tdKsdK = gmem_thr_copy_dK.partition_S(sdK) + tdKgdK = gmem_thr_copy_dK.partition_D(gdK) + tdVsdV = gmem_thr_copy_dV.partition_S(sdV) + tdVgdV = gmem_thr_copy_dV.partition_D(gdV) + tdKrdK = cute.make_fragment_like(tdKgdK, self.dtype) + tdVrdV = cute.make_fragment_like(tdVgdV, self.dtype) + # sync before all smem stores are done. + cute.arch.barrier() + # load acc dK and dV from smem to rmem for wider vectorization + # Need to check OOB when reading from smem if kBlockN isn't evenly tiled + # TODO + cute.autovec_copy(tdKsdK, tdKrdK) + cute.autovec_copy(tdVsdV, tdVrdV) + + cdK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + tdKcdK = gmem_thr_copy_dK.partition_S(cdK) + t0dKcdK = gmem_tiled_copy_dK.get_slice(0).partition_S(cdK) + if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded): + tdVcdV = tdKcdK + t0dVcdV = t0dKcdK + else: + cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded)) + tdVcdV = gmem_thr_copy_dV.partition_S(cdV) + t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV) + tdKpdK = utils.predicate_k(tdKcdK, limit=d_head) + if cutlass.const_expr(self.same_hdim_kv): + tdVpdV = tdKpdK + else: + tdVpdV = utils.predicate_k(tdVcdV, limit=d_head_v) + # copy acc dK and acc_dV from rmem to gmem + for rest_m in cutlass.range_constexpr(cute.size(tdKrdK.shape[1])): + if t0dKcdK[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdKcdK[0][0]: + cute.copy( + gmem_tiled_copy_dK, + tdKrdK[None, rest_m, None], + tdKgdK[None, rest_m, None], + pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])): + if t0dVcdV[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdVcdV[0][0]: + cute.copy( + gmem_tiled_copy_dV, + tdVrdV[None, rest_m, None], + tdVgdV[None, rest_m, None], + pred=tdVpdV[None, rest_m, None] if cutlass.const_expr(self.check_hdim_v_oob) else None, + ) + + else: # qhead_per_kvhead > 1, do atomic add + # For Sm90, we need to sync to avoid racy writes to smem_q + # For Sm80, we don't need to sync since we're not touching smem + head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head + + if cutlass.const_expr(not seqlen.has_cu_seqlens_k): + mdK_cur, mdV_cur = [t[batch_idx, head_idx_kv, None] for t in (mdK, mdV)] + else: + padded_offset_k = 
seqlen.offset_k + batch_idx * self.n_block_size + mdK_cur = cute.domain_offset((padded_offset_k * self.head_dim_padded,), mdK[head_idx_kv, None]) + mdV_cur = cute.domain_offset((padded_offset_k * self.head_dim_v_padded,), mdV[head_idx_kv, None]) + + gdV = cute.local_tile(mdV_cur, (self.n_block_size * self.head_dim_v_padded,), (n_block,)) + gdK = cute.local_tile(mdK_cur, (self.n_block_size * self.head_dim_padded,), (n_block,)) + tdVgdVaccum = gmem_thr_copy_dV.partition_S(gdV) + tdKgdKaccum = gmem_thr_copy_dK.partition_S(gdK) + acc_dV_atomic = gmem_thr_copy_dV.retile(acc_dV) + acc_dK_atomic = gmem_thr_copy_dK.retile(acc_dK) + assert cute.size(acc_dV_atomic) == cute.size(tdVgdVaccum) + assert cute.size(acc_dK_atomic) == cute.size(tdKgdKaccum) + for i in cutlass.range(cute.size(acc_dV_atomic), unroll_full=True): + utils.atomic_add_fp32(acc_dV_atomic[i], utils.elem_pointer(tdVgdVaccum, i)) + for i in cutlass.range(cute.size(acc_dK_atomic), unroll_full=True): + utils.atomic_add_fp32(acc_dK_atomic[i], utils.elem_pointer(tdKgdKaccum, i)) + + @cute.jit + def advance_pipeline(self, pipeline_index, num_stages: cutlass.Constexpr): + return pipeline_index + 1 if pipeline_index < num_stages - 1 else 0 + + @cute.jit + def load_K( + self, + gmem_thr_copy: cute.TiledCopy, + tKgK: cute.Tensor, + tKsK: cute.Tensor, + block: cutlass.Int32, + seqlen: cutlass.Int32, + headdim: cutlass.Int32, + ): + cK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + tKcK = gmem_thr_copy.partition_S(cK) + t0KcK = gmem_thr_copy.get_slice(0).partition_S(cK) + tKpK = utils.predicate_k(tKcK, limit=headdim) + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): + # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked + if self.is_even_n_smem_k or n < cute.size(tKsK.shape[1]) - 1 or tKcK[0, n, 0][0] < self.n_block_size: + # Instead of using tKcK, we using t0KcK and subtract the offset from the limit + # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time. + predicate_n = t0KcK[0, n, 0][0] < seqlen - block * self.n_block_size - tKcK[0][0] + predicate = cute.make_fragment_like(tKpK[None, 0, None]) + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tKpK[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n + cute.copy( + gmem_thr_copy, tKgK[None, n, None], tKsK[None, n, None], pred=predicate, + ) + # We need to clear the sK smem tiles since we'll use sKt for mma_dq + + @cute.jit + def load_V( + self, + gmem_thr_copy: cute.TiledCopy, + tVgV: cute.Tensor, + tVsV: cute.Tensor, + block: cutlass.Int32, + seqlen: cutlass.Int32, + headdim: cutlass.Int32, + ): + cV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded)) + tVcV = gmem_thr_copy.partition_S(cV) + t0VcV = gmem_thr_copy.get_slice(0).partition_S(cV) + tVpV = utils.predicate_k(tVcV, limit=headdim) + for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): + # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked + if self.is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size: + # Instead of using tVcV, we using t0VcV and subtract the offset from the limit + # (seqlen - block * kBlockN). This is because the entries of t0VcV are known at compile time. 
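+                # Illustrative numbers (any consistent values work): with kBlockN = 128, block = 7
+                # and seqlen = 1000, the right-hand side is 1000 - 7 * 128 - tVcV[0][0] = 104 minus
+                # this thread's starting row, so the compile-time t0VcV coordinate is compared
+                # against a single runtime scalar.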
+ predicate_n = t0VcV[0, n, 0][0] < seqlen - block * self.n_block_size - tVcV[0][0] + predicate = cute.make_fragment_like(tVpV[None, 0, None]) + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tVpV[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n + cute.copy( + gmem_thr_copy, tVgV[None, n, None], tVsV[None, n, None], pred=predicate, + ) + + @cute.jit + def load_Q_LSE( + self, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + tQgQ: cute.Tensor, + tQsQ: cute.Tensor, + tQcQ: cute.Tensor, + t0QcQ: cute.Tensor, + tQpQ: cute.Tensor, + tLSEgLSE: cute.Tensor, + tLSEsLSE: cute.Tensor, + tLSEcLSE: cute.Tensor, + block: cutlass.Int32, + smem_pipe_write_q: cutlass.Int32, + seqlen: cutlass.Int32, + ): + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked + if self.is_even_m_smem_q or m < cute.size(tQsQ.shape[1]) - 1 or tQcQ[0, m, 0][0] < self.m_block_size: + # Instead of using tQcQ, we using t0QcQ and subtract the offset from the limit + # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time. + predicate_m = t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0] + predicate = cute.make_fragment_like(tQpQ[None, 0, None]) + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tQpQ[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m + cute.copy( + gmem_tiled_copy_Q, + tQgQ[None, m, None, block], + tQsQ[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q) > 1 else 0], + pred=predicate, + ) + # We need to clear the sQ smem tiles since we'll use sQt for mma_dK + # We made sure LSE length is padded so we read `kBlockM` elements so that all + # elements in sLSE are filled. Without this we might have uninitialized sLSE values. + for m in cutlass.range_constexpr(cute.size(tLSEsLSE.shape[1])): + if tLSEcLSE[0, m][0] < self.m_block_size: + cute.copy( + gmem_tiled_copy_LSE, + tLSEgLSE[None, m, block], + tLSEsLSE[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], + ) + + @cute.jit + def load_dO_dPsum( + self, + gmem_tiled_copy_dO: cute.TiledCopy, + gmem_tiled_copy_dPsum: cute.TiledCopy, + tdOgdO: cute.Tensor, + tdOsdO: cute.Tensor, + tdOcdO: cute.Tensor, + t0dOcdO: cute.Tensor, + tdOpdO: cute.Tensor, + tdPsumgdPsum: cute.Tensor, + tdPsumsdPsum: cute.Tensor, + tdPsumcdPsum: cute.Tensor, + block: cutlass.Int32, + smem_pipe_write_q: cutlass.Int32, + seqlen: cutlass.Int32, + ): + for m in cutlass.range_constexpr(cute.size(tdOsdO.shape[1])): + # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked + if self.is_even_m_smem_do or m < cute.size(tdOsdO.shape[1]) - 1 or tdOcdO[0, m, 0][0] < self.m_block_size: + # Instead of using tdOcdO, we using t0dOcdO and subtract the offset from the limit + # (seqlen - block * kBlockM). This is because the entries of t0dOcdO are known at compile time. 
+ predicate_m = t0dOcdO[0, m, 0][0] < seqlen - block * self.m_block_size - tdOcdO[0][0] + predicate = cute.make_fragment_like(tdOpdO[None, 0, None]) + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tdOpdO[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m + cute.copy( + gmem_tiled_copy_dO, + tdOgdO[None, m, None, block], + tdOsdO[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0], + pred=predicate, + ) + # We need to clear the sQ smem tiles since we'll use sQt for mma_dK + # We made sure LSE length is padded so we read `kBlockM` elements so that all + # elements in sLSE are filled. Without this we might have uninitialized sLSE values. + for m in cutlass.range_constexpr(cute.size(tdPsumgdPsum.shape[1])): + if tdPsumcdPsum[0, m][0] < self.m_block_size: + cute.copy( + gmem_tiled_copy_dPsum, + tdPsumgdPsum[None, m, block], + tdPsumsdPsum[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0], + ) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py new file mode 100644 index 00000000000..5b1a3acae64 --- /dev/null +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -0,0 +1,463 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h +# from Cutlass C++ to Cute-DSL. +import math +from typing import Callable, Optional, Type, Literal + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.utils.hopper_helpers as sm90_utils_basic +import cutlass.utils.blackwell_helpers as sm100_utils_basic +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass import Float32, const_expr +from cutlass.utils import LayoutEnum + +from flash_attn.cute import utils +from flash_attn.cute import copy_utils +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute import hopper_helpers as sm90_utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +from flash_attn.cute.tile_scheduler import ( + ParamsBase, + SingleTileScheduler, + SingleTileVarlenScheduler, + TileSchedulerArguments, +) + + +class FlashAttentionBackwardPostprocess: + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + arch: Literal[80, 90, 100], + tile_m: int = 128, + num_threads: int = 256, + AtomLayoutMdQ: int = 1, + dQ_swapAB: bool = False, + ): + """ + :param head_dim: head dimension + :type head_dim: int + :param tile_m: m block size + :type tile_m: int + """ + self.dtype = dtype + self.tile_m = tile_m + assert arch in [80, 90, 100], ( + "Only Ampere (80), Hopper (90), and Blackwell (100) are supported" + ) + self.arch = arch + # padding head_dim to a multiple of 32 as k_block_size + hdim_multiple_of = 32 + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + self.check_hdim_oob = head_dim != self.tile_hdim + self.num_threads = num_threads + self.AtomLayoutMdQ = AtomLayoutMdQ + self.dQ_swapAB = dQ_swapAB + + @staticmethod + def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: + """Check if the kernel can be implemented with the given parameters. 
+ + :param dtype: data type + :type dtype: cutlass.Numeric + :param head_dim: head dimension + :type head_dim: int + :param tile_m: m block size + :type tile_m: int + + :return: True if the kernel can be implemented, False otherwise + :rtype: bool + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if num_threads % 32 != 0: + return False + return True + + def _get_tiled_mma(self): + if const_expr(self.arch == 80): + num_mma_warps = self.num_threads // 32 + atom_layout_dQ = ( + (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) + if const_expr(not self.dQ_swapAB) + else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) + ) + tiled_mma = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), + atom_layout_dQ, + permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16), + ) + elif const_expr(self.arch == 90): + num_mma_warp_groups = self.num_threads // 128 + atom_layout_dQ = (self.AtomLayoutMdQ, num_mma_warp_groups // self.AtomLayoutMdQ) + tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1]) + tiled_mma = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, # These don't matter, we only care about the accum + warpgroup.OperandMajorMode.K, + Float32, + atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + + (1,), + tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1], + ) + else: + cta_group = tcgen05.CtaGroup.ONE + tiled_mma = sm100_utils_basic.make_trivial_tiled_mma( + self.dtype, + tcgen05.OperandMajorMode.MN, # dS_major_mode + tcgen05.OperandMajorMode.MN, # Kt_major_mode + Float32, + cta_group, + (self.tile_m, self.tile_hdim), + ) + if const_expr(self.arch in [80, 90]): + assert self.num_threads == tiled_mma.size + return tiled_mma + + def _setup_attributes(self): + # /////////////////////////////////////////////////////////////////////////////// + # GMEM Tiled copy: + # /////////////////////////////////////////////////////////////////////////////// + # Thread layouts for copies + universal_copy_bits = 128 + async_copy_elems_accum = universal_copy_bits // Float32.width + atom_async_copy_accum = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + Float32, + num_bits_per_copy=universal_copy_bits, + ) + # We don't do bound checking for the gmem -> smem load so we just assert here. 
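+        # For example (a typical configuration, not a requirement): tile_m = 128, tile_hdim = 128
+        # and 128-bit copies of Float32 give 4 elements per copy, i.e. 128 * 128 / 4 = 4096 vector
+        # copies, which splits evenly across 256 threads.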
+ assert (self.tile_m * self.tile_hdim // async_copy_elems_accum) % self.num_threads == 0 + self.g2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + atom_async_copy_accum, + cute.make_layout(self.num_threads), + cute.make_layout(async_copy_elems_accum), + ) + num_s2r_copy_elems = 1 if const_expr(self.arch == 80) else 4 + if const_expr(self.arch == 80): + self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( + Float32, self.num_threads, num_s2r_copy_elems + ) + self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) + elif const_expr(self.arch == 90): + num_threads_per_warp_group = 128 + num_mma_warp_groups = self.num_threads // 128 + self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), + cute.make_layout((num_threads_per_warp_group, num_mma_warp_groups)), # thr_layout + cute.make_layout(128 // Float32.width), # val_layout + ) + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.tile_hdim // num_mma_warp_groups, num_mma_warp_groups) + ) + else: + self.dQ_reduce_ncol = 32 + dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol + assert self.num_threads == 128 # TODO: currently hard-coded + self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( + Float32, self.num_threads, num_s2r_copy_elems + ) + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.tile_hdim // dQaccum_reduce_stage, dQaccum_reduce_stage) + ) + + self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d( + self.dtype, self.tile_hdim, self.num_threads + ) + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: dQ + # /////////////////////////////////////////////////////////////////////////////// + # We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs, + # then setting kBlockKSmem to 32 will cause "Static shape_div failure". + # We want to treat it as 64 x 48, so kBlockKSmem should be 16. 
+ mma_shape_n = self.tiled_mma.get_tile_size(1) + if const_expr(self.arch == 80): + sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n) + self.sdQ_layout = cute.tile_to_shape( + sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1) + ) + elif const_expr(self.arch == 90): + self.sdQ_layout = sm90_utils.make_smem_layout( + self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim) + ) + else: + # TODO: this is hard-coded for hdim 128 + self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi( + self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim), 1 + ) + + @cute.jit + def __call__( + self, + mdQaccum: cute.Tensor, + mdQ: cute.Tensor, + scale: cutlass.Float32, + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + stream: cuda.CUstream, + ): + # Get the data type and check if it is fp16 or bf16 + if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]): + raise TypeError("Only Float16 or BFloat16 is supported") + if const_expr(mdQaccum is not None): + if const_expr(mdQaccum.element_type not in [cutlass.Float32]): + raise TypeError("dQaccum tensor must be Float32") + + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mdQaccum, mdQ = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mdQaccum, mdQ) + ] + + self.tiled_mma = self._get_tiled_mma() + self._setup_attributes() + + smem_size = max( + cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout), + cute.size_in_bytes(self.dtype, self.sdQ_layout), + ) + + if const_expr(mCuSeqlensQ is not None): + TileScheduler = SingleTileVarlenScheduler + num_head = mdQ.shape[1] + num_batch = mCuSeqlensQ.shape[0] - 1 + num_block = cute.ceil_div(mdQ.shape[0], self.tile_m) + else: + TileScheduler = SingleTileScheduler + num_head = mdQ.shape[2] + num_batch = mdQ.shape[0] + num_block = cute.ceil_div(mdQ.shape[1], self.tile_m) + + tile_sched_args = TileSchedulerArguments( + num_block=num_block, + num_head=num_head, + num_batch=num_batch, + num_splits=1, + seqlen_k=0, + headdim=mdQ.shape[2], + headdim_v=0, + total_q=mdQ.shape[0], + tile_shape_mn=(self.tile_m, 1), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + + # grid_dim: (m_block, num_head, batch_size) + self.kernel( + mdQaccum, + mdQ, + mCuSeqlensQ, + mSeqUsedQ, + scale, + self.tiled_mma, + self.dQ_swapAB, + self.sdQaccum_layout, + self.sdQ_layout, + self.g2s_tiled_copy_dQaccum, + self.s2r_tiled_copy_dQaccum, + self.gmem_tiled_copy_dQ, + tile_sched_params, + TileScheduler, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=smem_size, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mdQaccum: cute.Tensor, + mdQ: cute.Tensor, + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + scale: cutlass.Float32, + tiled_mma: cute.TiledMma, + dQ_swapAB: cutlass.Constexpr, + sdQaccum_layout: cute.Layout, + sdQ_layout: cute.ComposedLayout, + g2s_tiled_copy_dQaccum: cute.TiledCopy, + s2r_tiled_copy_dQaccum: cute.TiledCopy, + gmem_tiled_copy_dQ: cute.TiledCopy, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + ): + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # 
/////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) + sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum))) + if const_expr(self.arch in [80, 90]): + sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) + else: + # extra stage dimension + sdQ = cute.make_tensor( + cute.recast_ptr(sdQaccum.iterator, sdQ_layout.inner, dtype=self.dtype), + sdQ_layout.outer, + )[None, None, 0] + sdQt = utils.transpose_view(sdQ) + + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + + tile_scheduler = TileScheduler.create(tile_sched_params) + work_tile = tile_scheduler.initial_work_tile_info() + + m_block, head_idx, batch_idx, _ = work_tile.tile_idx + + if work_tile.is_valid_tile: + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # /////////////////////////////////////////////////////////////////////////////// + + seqlen = SeqlenInfoQK.create( + batch_idx, + mdQ.shape[1], + 0, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=None, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=None, + ) + if const_expr(not seqlen.has_cu_seqlens_q): + mdQ_cur = mdQ[batch_idx, None, head_idx, None] + mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] + head_dim = mdQ.shape[3] + else: + padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m + if cutlass.const_expr(self.arch >= 90): + padded_offset_q = padded_offset_q // self.tile_m * self.tile_m + mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None]) + mdQaccum_cur = cute.domain_offset( + (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None] + ) + head_dim = mdQ.shape[2] + + # HACK: Compiler doesn't seem to recognize that padding + # by padded_offset_q * self.tile_hdim keeps alignment + # since statically divisible by 4 + + mdQaccum_cur_ptr = cute.make_ptr( + dtype=mdQaccum_cur.element_type, + value=mdQaccum_cur.iterator.toint(), + mem_space=mdQaccum_cur.iterator.memspace, + assumed_align=mdQaccum.iterator.alignment, + ) + mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout) + + gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,)) + gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) + + seqlen_q = seqlen.seqlen_q + seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) + + # Step 1: load dQaccum from gmem to smem + g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) + tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum) + tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum_flat) + cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + # Step 2: load dQ from smem to rmem + s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) + tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum) + tile_shape = (self.tile_m, self.tile_hdim) + acc = None + tiled_copy_t2r = None + if const_expr(self.arch in [80, 90]): + acc_shape = tiled_mma.partition_shape_C( + tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1] + ) + acc = cute.make_fragment(acc_shape, cutlass.Float32) + assert cute.size(acc) == cute.size(tdQsdQaccum) + else: + thr_mma = tiled_mma.get_slice(0) # 1-CTA + dQacc_shape = tiled_mma.partition_shape_C((self.tile_m, 
self.tile_hdim)) + tdQtdQ = tiled_mma.make_fragment_C(dQacc_shape) + tdQcdQ = thr_mma.partition_C( + cute.make_identity_tensor((self.tile_m, self.tile_hdim)) + ) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32 + ) + tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape + acc = cute.make_fragment(tdQrdQ_t2r_shape, Float32) + tdQrdQaccum = cute.make_tensor(acc.iterator, cute.make_layout(tdQsdQaccum.shape)) + cute.autovec_copy(tdQsdQaccum, tdQrdQaccum) + # Convert tdQrdQaccum from fp32 to fp16/bf16 + rdQ = cute.make_fragment_like(acc, self.dtype) + rdQ.store((acc.load() * scale).to(self.dtype)) + + # Step 3: Copy dQ from register to smem + cute.arch.barrier() # make sure all threads have finished loading dQaccum + if const_expr(self.arch in [80, 90]): + copy_atom_r2s_dQ = utils.get_smem_store_atom( + self.arch, self.dtype, transpose=self.dQ_swapAB + ) + tiled_copy_r2s_dQ = cute.make_tiled_copy_C(copy_atom_r2s_dQ, tiled_mma) + else: + # copy_atom_r2s_dQ = sm100_utils_basic.get_smem_store_op( + # LayoutEnum.ROW_MAJOR, self.dtype, Float32, tiled_copy_t2r, + # ) + # tiled_copy_r2s_dQ = cute.make_tiled_copy_D(copy_atom_r2s_dQ, tiled_copy_t2r) + thr_layout_r2s_dQ = cute.make_layout((self.num_threads, 1)) # 128 threads + val_layout_r2s_dQ = cute.make_layout((1, 128 // self.dtype.width)) + copy_atom_r2s_dQ = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=128, + ) + tiled_copy_r2s_dQ = cute.make_tiled_copy_tv( + copy_atom_r2s_dQ, thr_layout_r2s_dQ, val_layout_r2s_dQ + ) + thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx) + cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) + if const_expr(self.arch in [80, 90]): + taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ) + else: + taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape + taccdQrdQ = cute.make_tensor(rdQ.iterator, taccdQcdQ_shape) + taccdQsdQ = thr_copy_r2s_dQ.partition_D(sdQ if const_expr(not self.dQ_swapAB) else sdQt) + cute.copy(thr_copy_r2s_dQ, taccdQrdQ, taccdQsdQ) + + # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem + cute.arch.barrier() # make sure all smem stores are done + gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx) + tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ) + tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ) + tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype) + # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled + cute.autovec_copy(tdQsdQ, tdQrdQ) + + # Step 5: Copy dQ from register to gmem + tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ) + tdQpdQ = utils.predicate_k(tdQcdQ, limit=head_dim) + for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True): + if tdQcdQ[0, rest_m, 0][0] < seqlen_q - m_block * self.tile_m: + cute.copy( + gmem_tiled_copy_dQ, + tdQrdQ[None, rest_m, None], + tdQgdQ[None, rest_m, None], + pred=tdQpdQ[None, rest_m, None], + ) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py new file mode 100644 index 00000000000..cd514316f88 --- /dev/null +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -0,0 +1,365 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h +# from Cutlass C++ to Cute-DSL. 
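+#
+# For orientation, a minimal reference sketch (hypothetical helper, assuming
+# PyTorch-style (seqlen, head_dim) tensors; not part of this kernel) of what the
+# preprocess step produces per (batch, head):
+#     dPsum[i]   = sum_d O[i, d] * dO[i, d]
+#     LSElog2[i] = LSE[i] * log2(e)
+# and dQaccum is zero-filled so the main backward kernel can accumulate into it.
+#
+#     def reference_preprocess(o, do, lse):
+#         import math
+#         dpsum = (o.float() * do.float()).sum(dim=-1)   # row-wise dot(O, dO)
+#         lse_log2 = lse.float() * math.log2(math.e)     # LSE rescaled to base 2
+#         return dpsum, lse_log2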
+import math +import operator +from typing import Callable, Type, Optional, Literal + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Float32 + +from flash_attn.cute import utils +from flash_attn.cute import copy_utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.tile_scheduler import ( + ParamsBase, + SingleTileScheduler, + SingleTileVarlenScheduler, + TileSchedulerArguments, +) + + +class FlashAttentionBackwardPreprocess: + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + arch: Literal[80, 90, 100], + m_block_size: int = 128, + num_threads: int = 128, + ): + """ + All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension + should be a multiple of 8. + + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param num_threads: number of threads + :type num_threads: int + """ + self.dtype = dtype + self.m_block_size = m_block_size + self.arch = arch + # padding head_dim to a multiple of 32 as k_block_size + hdim_multiple_of = 32 + self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + self.check_hdim_oob = head_dim != self.head_dim_padded + self.num_threads = num_threads + + @staticmethod + def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool: + """Check if the kernel can be implemented with the given parameters. + + :param dtype: data type + :type dtype: cutlass.Numeric + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param num_threads: number of threads + :type num_threads: int + + :return: True if the kernel can be implemented, False otherwise + :rtype: bool + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if num_threads % 32 != 0: + return False + if num_threads < m_block_size: # For multiplying lse with log2 + return False + return True + + def _setup_attributes(self): + # /////////////////////////////////////////////////////////////////////////////// + # GMEM Tiled copy: + # /////////////////////////////////////////////////////////////////////////////// + # Thread layouts for copies + # We want kBlockKGmem to be a power of 2 so that when we do the summing, + # it's just between threads in the same warp + gmem_k_block_size = ( + 128 + if self.head_dim_padded % 128 == 0 + else ( + 64 + if self.head_dim_padded % 64 == 0 + else (32 if self.head_dim_padded % 32 == 0 else 16) + ) + ) + self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d( + self.dtype, gmem_k_block_size, self.num_threads + ) + universal_copy_bits = 128 + num_copy_elems_dQaccum = universal_copy_bits // Float32.width + assert ( + self.m_block_size * self.head_dim_padded // num_copy_elems_dQaccum + ) % self.num_threads == 0 + self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( + Float32, self.num_threads, num_copy_elems_dQaccum + ) + + @cute.jit + def __call__( + self, + mO: cute.Tensor, + mdO: cute.Tensor, + mdPsum: cute.Tensor, + mLSE: Optional[cute.Tensor], + mLSElog2: Optional[cute.Tensor], + mdQaccum: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + stream: cuda.CUstream, + ): + # Get the data type and check if it is fp16 or bf16 + if cutlass.const_expr(not (mO.element_type == mdO.element_type)): + raise TypeError("All tensors must have the same data type") + if 
cutlass.const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]): + raise TypeError("Only Float16 or BFloat16 is supported") + if cutlass.const_expr(mdPsum.element_type not in [Float32]): + raise TypeError("dPsum tensor must be Float32") + if cutlass.const_expr(mdQaccum is not None): + if cutlass.const_expr(mdQaccum.element_type not in [Float32]): + raise TypeError("dQaccum tensor must be Float32") + if cutlass.const_expr(mLSE is not None): + assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided" + if cutlass.const_expr(mLSE.element_type not in [Float32]): + raise TypeError("LSE tensor must be Float32") + if cutlass.const_expr(mLSElog2.element_type not in [Float32]): + raise TypeError("LSElog2 tensor must be Float32") + + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mO, mdO, mdQaccum = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in (mO, mdO, mdQaccum) + ] + + self._setup_attributes() + + if cutlass.const_expr(mCuSeqlensQ is not None): + TileScheduler = SingleTileVarlenScheduler + num_head = mO.shape[1] + num_batch = mCuSeqlensQ.shape[0] - 1 + else: + TileScheduler = SingleTileScheduler + num_head = mO.shape[2] + num_batch = mO.shape[0] + + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(mO.shape[1], self.m_block_size), + num_head=num_head, + num_batch=num_batch, + num_splits=1, + seqlen_k=0, + headdim=0, + headdim_v=mO.shape[2], + total_q=mO.shape[0], + tile_shape_mn=(self.m_block_size, 1), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + + self.kernel( + mO, + mdO, + mdPsum, + mLSE, + mLSElog2, + mdQaccum, + mCuSeqlensQ, + mSeqUsedQ, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_dQaccum, + tile_sched_params, + TileScheduler, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mO: cute.Tensor, + mdO: cute.Tensor, + mdPsum: cute.Tensor, + mLSE: Optional[cute.Tensor], + mLSElog2: Optional[cute.Tensor], + mdQaccum: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_dQaccum: cute.TiledCopy, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + ): + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + + tile_scheduler = TileScheduler.create(tile_sched_params) + work_tile = tile_scheduler.initial_work_tile_info() + m_block, head_idx, batch_idx, _ = work_tile.tile_idx + + if work_tile.is_valid_tile: + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. 
+ # /////////////////////////////////////////////////////////////////////////////// + seqlen = SeqlenInfoQK.create( + batch_idx, + mO.shape[1], + 0, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=None, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=None, + ) + + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mO_cur = mO[batch_idx, None, head_idx, None] + mdO_cur = mdO[batch_idx, None, head_idx, None] + mdPsum_cur = mdPsum[batch_idx, head_idx, None] + headdim_v = mO.shape[3] + else: + mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None]) + mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None]) + + padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size + if cutlass.const_expr(self.arch >= 90): + padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size + mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None]) + headdim_v = mO.shape[2] + + blkOdO_shape = (self.m_block_size, self.head_dim_padded) + # (m_block_size, head_dim) + gO = cute.local_tile(mO_cur, blkOdO_shape, (m_block, 0)) + gdO = cute.local_tile(mdO_cur, blkOdO_shape, (m_block, 0)) + + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + # (CPY_Atom, CPY_M, CPY_K) + tOgO = gmem_thr_copy_O.partition_S(gO) + tOgdO = gmem_thr_copy_O.partition_S(gdO) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for KV + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=headdim_v) + tOpdO = utils.predicate_k(tOcO, limit=headdim_v) + + seqlen_q = seqlen.seqlen_q + seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) + + if cutlass.const_expr(mLSE is not None): + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[batch_idx, head_idx, None] + else: + mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[head_idx, None]) + + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,)) + lse = Float32.inf + if tidx < seqlen_q - m_block * self.m_block_size: + lse = gLSE[tidx] + + tOrO = cute.make_fragment_like(tOgO) + tOrdO = cute.make_fragment_like(tOgdO) + assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0]) + assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1]) + assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2]) + for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True): + # Instead of using tOcO, we using t0OcO and subtract the offset from the limit + # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time. 
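+                # Equivalently, the row predicate is evaluated as
+                #     t0OcO_row < (seqlen_q - m_block * kBlockM) - tOcO_row_offset
+                # so the compile-time-constant coordinates stay on the left-hand side
+                # and only the right-hand side varies at runtime.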
+ if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]: + cute.copy( + gmem_thr_copy_O, + tOgO[None, m, None], + tOrO[None, m, None], + pred=tOpO[None, m, None] + if cutlass.const_expr(self.check_hdim_oob) + else None, + ) + cute.copy( + gmem_thr_copy_O, + tOgdO[None, m, None], + tOrdO[None, m, None], + pred=tOpdO[None, m, None] + if cutlass.const_expr(self.check_hdim_oob) + else None, + ) + # Sum across the "k" dimension + dpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce( + cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1) + ) + threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0] + assert cute.arch.WARP_SIZE % threads_per_row == 0 + dpsum = utils.warp_reduce(dpsum, operator.add, width=threads_per_row) + dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), Float32) + dP_sum.store(dpsum) + + # Write dPsum from rmem -> gmem + gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (m_block,)) + # Only the thread corresponding to column 0 writes out the dPsum to gmem + if tOcO[0, 0, 0][1] == 0: + for m in cutlass.range(cute.size(dP_sum), unroll_full=True): + row = tOcO[0, m, 0][0] + gdPsum[row] = dP_sum[m] if row < seqlen_q - m_block * self.m_block_size else 0.0 + + # Clear dQaccum + if cutlass.const_expr(mdQaccum is not None): + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] + else: + mdQaccum_cur = cute.domain_offset( + (padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None] + ) + + # HACK: Compiler doesn't seem to recognize that padding + # by padded_offset_q * self.head_dim_padded keeps alignment + # since statically divisible by 4 + + mdQaccum_cur_ptr = cute.make_ptr( + dtype=mdQaccum_cur.element_type, + value=mdQaccum_cur.iterator.toint(), + mem_space=mdQaccum_cur.iterator.memspace, + assumed_align=mdQaccum.iterator.alignment, + ) + mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout) + + blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) + gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,)) + gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) + tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) + zero = cute.make_fragment_like(tdQgdQaccum) + zero.fill(0.0) + cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum) + + if cutlass.const_expr(mLSE is not None): + if cutlass.const_expr(not seqlen.has_cu_seqlens_q): + mLSElog2_cur = mLSElog2[batch_idx, head_idx, None] + else: + mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[head_idx, None]) + + gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,)) + LOG2_E = math.log2(math.e) + if tidx < seqlen_q_rounded - m_block * self.m_block_size: + gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py new file mode 100644 index 00000000000..0b0488963ba --- /dev/null +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -0,0 +1,2950 @@ +# Copyright (c) 2025, Ted Zadouri, Markus Hoehnerbach, Jay Shah, Tri Dao. 
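+#
+# Per (n_block, head, batch) tile, this kernel implements the standard
+# FlashAttention backward recurrences (a sketch; the exact tile shapes are given
+# by the MMA tiler comments in __init__ below, and LSElog2 / dPsum come from the
+# preprocess kernel):
+#     S   = K @ Q^T
+#     P   = exp2(S * softmax_scale_log2 - LSElog2)
+#     dP  = V @ dO^T
+#     dS  = P * (dP - dPsum)
+#     dV += P^T  @ dO
+#     dK += dS^T @ Q
+#     dQ  = dS   @ K        (accumulated into dQaccum, reduced by a separate kernel)
+# Work is split across 16 warps: 4 reduce warps (dQ), 8 compute warps, and one
+# MMA, load, epilogue, and empty warp each.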
+import math +from typing import Callable, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute import FastDivmodDivisor +from cutlass import Float32, Int32, const_expr +from cutlass.utils import LayoutEnum +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils_basic +from cutlass.pipeline import PipelineAsync, PipelineConsumer + +from flash_attn.cute import utils +from flash_attn.cute import copy_utils +from flash_attn.cute import pipeline +from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + SingleTileLPTBwdScheduler, # noqa + SingleTileVarlenScheduler, + ParamsBase, +) + +from flash_attn.cute import barrier +from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 +from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + get_total_q_block_count_bwd, + get_block_sparse_iteration_info_bwd, + get_m_block_from_iter_bwd, + produce_block_sparse_q_loads_bwd_sm100, +) + + +class FlashAttentionBackwardSm100: + arch = 100 + + def __init__( + self, + head_dim: int, + head_dim_v: Optional[int] = None, + is_causal: bool = False, + is_local: bool = False, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, + tile_m: int = 128, + tile_n: int = 128, + is_persistent: bool = False, + deterministic: bool = False, + cluster_size: int = 1, + score_mod: cutlass.Constexpr | None = None, + score_mod_bwd: cutlass.Constexpr | None = None, + mask_mod: cutlass.Constexpr | None = None, + has_aux_tensors: cutlass.Constexpr = False, + subtile_factor: cutlass.Constexpr[int] = 1, + ): + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + assert head_dim == head_dim_v, "head_dim and head_dim_v must be the same for now" + self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + assert self.tile_hdim == self.tile_hdimv, ( + "tile_hdim and tile_hdimv must be the same for now" + ) + self.check_hdim_oob = head_dim != self.tile_hdim + self.check_hdim_v_oob = head_dim_v != self.tile_hdimv + + self.tile_m = tile_m + self.tile_n = tile_n + + # CTA tiler + self.cta_tiler = (tile_n, tile_m, self.tile_hdim) + # S = K @ Q.T + self.mma_tiler_kq = (tile_n, tile_m, self.tile_hdim) + # dP = V @ dO.T + self.mma_tiler_vdo = (tile_n, tile_m, self.tile_hdimv) + # dV = P.T @ dO + self.mma_tiler_pdo = (tile_n, self.tile_hdimv, tile_m) + # dK = dS.T @ Q (N, M) (M, D) + self.mma_tiler_dsq = (tile_n, self.tile_hdimv, tile_m) + # dQ = dS @ K + self.mma_tiler_dsk = (tile_m, self.tile_hdimv, tile_n) + + self.acc_dtype = Float32 + + assert cluster_size in (1, 2), "Only cluster_size=1 or 2 is supported" + self.cluster_shape_mn = (cluster_size, 1) + self.is_persistent = is_persistent + self.is_causal = is_causal + self.is_local = is_local + self.qhead_per_kvhead = qhead_per_kvhead + self.pack_gqa = False + self.deterministic = deterministic + + # Score 
mod and mask mod support + self.score_mod = score_mod + self.score_mod_bwd = score_mod_bwd + self.mask_mod = mask_mod + self.has_aux_tensors = has_aux_tensors + self.subtile_factor = subtile_factor + # For score_mod, use vec_size=1 (like forward) to handle per-element indices + if cutlass.const_expr(has_aux_tensors): + self.vec_size: cutlass.Constexpr = 1 + else: + self.vec_size: cutlass.Constexpr = 4 + self.qk_acc_dtype = Float32 + + # Speed optimizations, does not affect correctness + self.shuffle_LSE = False + self.shuffle_dPsum = False + self.use_smem_dS_for_mma_dK = self.deterministic and self.is_causal + + self.reduce_warp_ids = (0, 1, 2, 3) + self.compute_warp_ids = (4, 5, 6, 7, 8, 9, 10, 11) + self.mma_warp_id = 12 + self.load_warp_id = 13 + self.epi_warp_id = 14 + self.empty_warp_id = 15 + + # 16 warps -> 512 threads + self.threads_per_cta = cute.arch.WARP_SIZE * len( + ( + *self.reduce_warp_ids, + *self.compute_warp_ids, + self.mma_warp_id, + self.load_warp_id, + self.epi_warp_id, + self.empty_warp_id, + ) + ) + + # NamedBarrier + self.compute_sync_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierBwdSm100.Compute), + num_threads=len(self.compute_warp_ids) * cute.arch.WARP_SIZE, + ) + # self.epilogue_sync_barrier = pipeline.NamedBarrier( + # barrier_id=2, + # num_threads=self.num_compute_warps * self.threads_per_warp, + # ) + self.reduce_sync_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), + num_threads=len(self.reduce_warp_ids) * cute.arch.WARP_SIZE, + ) + + # TMEM setup + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + # self.tmem_dK_offset = 0 + # self.tmem_dV_offset = self.tmem_dK_offset + self.tile_hdim + # self.tmem_dQ_offset = self.tmem_dV_offset + self.tile_hdimv + # self.tmem_dP_offset = self.tmem_dQ_offset # overlap with dQ + # self.tmem_S_offset = self.tmem_dQ_offset + max(self.tile_m, self.tile_hdim) + # self.tmem_P_offset = self.tmem_S_offset # overlap with S + # self.tmem_total = self.tmem_S_offset + self.tile_n + # assert self.tmem_total <= self.tmem_alloc_cols + + self.tmem_S_offset = 0 + self.tmem_P_offset = 0 # overlap with S + self.tmem_dV_offset = self.tmem_S_offset + self.tile_n + self.tmem_dP_offset = self.tmem_dV_offset + self.tile_hdimv + self.tmem_dQ_offset = self.tmem_dP_offset # overlap with dP + self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m + self.tmem_dS_offset = self.tmem_dP_offset # overlap with dP + + if (not is_causal and not is_local) or deterministic: + self.num_regs_reduce = 152 + self.num_regs_compute = 136 + else: + self.num_regs_reduce = 136 + self.num_regs_compute = 144 + self.num_regs_other = 96 - 8 + self.num_regs_empty = 24 + assert self.num_regs_reduce + self.num_regs_compute * 2 + self.num_regs_other <= 512 + + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + self.Q_stage = 2 + self.dO_stage = 1 + # LSE_stage = Q_stage and dPsum_stage = dO_stage + # self.sdKVaccum_stage = 2 + # number of tma reduce adds per dQacc mma + self.dQ_reduce_ncol = 32 + self.sdQaccum_stage = 64 // self.dQ_reduce_ncol + assert self.tile_hdim % self.dQ_reduce_ncol == 0 + self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol + self.cluster_reduce_dQ = False and cute.size(self.cluster_shape_mn) > 1 + # number of tma reduce adds for dKacc and dVacc epilogue + self.dK_reduce_ncol = 32 + + def _get_tiled_mma(self): + cta_group = tcgen05.CtaGroup.ONE + # S = K @ Q.T + tiled_mma_S = sm100_utils_basic.make_trivial_tiled_mma( 
+ self.q_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + cta_group, + self.mma_tiler_kq[:2], + ) + # dP = V @ dO.T + tiled_mma_dP = sm100_utils_basic.make_trivial_tiled_mma( + self.do_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + cta_group, + self.mma_tiler_vdo[:2], + ) + # dV += P @ dO --> (K, MN) major + tiled_mma_dV = sm100_utils_basic.make_trivial_tiled_mma( + self.do_dtype, + tcgen05.OperandMajorMode.K, # P_major_mode + tcgen05.OperandMajorMode.MN, # dO_major_mode + self.acc_dtype, + cta_group, + self.mma_tiler_pdo[:2], + a_source=tcgen05.OperandSource.TMEM, + ) + # dK += dS.T @ Q + if const_expr(self.use_smem_dS_for_mma_dK): + mma_dK_a_src = tcgen05.OperandSource.SMEM + else: + mma_dK_a_src = tcgen05.OperandSource.TMEM + tiled_mma_dK = sm100_utils_basic.make_trivial_tiled_mma( + self.do_dtype, + tcgen05.OperandMajorMode.K, # dS_major_mode + tcgen05.OperandMajorMode.MN, # Q_major_mode + self.acc_dtype, + cta_group, + self.mma_tiler_dsq[:2], + a_source=mma_dK_a_src, + ) + # dQ = dS @ K + tiled_mma_dQ = sm100_utils_basic.make_trivial_tiled_mma( + self.k_dtype, + tcgen05.OperandMajorMode.MN, # dS_major_mode + tcgen05.OperandMajorMode.MN, # Kt_major_mode + self.acc_dtype, + cta_group, + self.mma_tiler_dsk[:2], + ) + return tiled_mma_S, tiled_mma_dP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ + + def _setup_smem_layout(self): + # S = K @ Q.T + sK_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_S, + self.mma_tiler_kq, + self.k_dtype, + 1, + ) + self.sK_layout = cute.slice_(sK_layout, (None, None, None, 0)) + self.sQ_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_S, + self.mma_tiler_kq, + self.q_dtype, + self.Q_stage, + ) + # dP = V @ dO.T + sV_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dP, + self.mma_tiler_vdo, + self.v_dtype, + 1, + ) + self.sV_layout = cute.slice_(sV_layout, (None, None, None, 0)) + self.sdOt_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_dP, + self.mma_tiler_vdo, + self.do_dtype, + self.dO_stage, + ) + # dV += P @ dO + tP_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dV, + self.mma_tiler_pdo, + self.do_dtype, + 1, + ) + self.tP_layout = cute.slice_(tP_layout, (None, None, None, 0)) + self.sdO_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_dV, + self.mma_tiler_pdo, + self.do_dtype, + self.dO_stage, + ) + # dK += dS.T @ Q + sdSt_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dK, + self.mma_tiler_dsq, + self.ds_dtype, + 1, + ) + self.sdSt_layout = cute.slice_(sdSt_layout, (None, None, None, 0)) + tdS_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dK, + self.mma_tiler_dsq, + self.ds_dtype, + 1, + ) + self.tdS_layout = cute.slice_(tdS_layout, (None, None, None, 0)) + self.sQt_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_dK, + self.mma_tiler_dsq, + self.q_dtype, + self.Q_stage, + ) + # dQ = dS @ K + sdS_layout = sm100_utils_basic.make_smem_layout_a( + self.tiled_mma_dQ, + self.mma_tiler_dsk, + self.ds_dtype, + 1, + ) + self.sdS_layout = cute.slice_(sdS_layout, (None, None, None, 0)) + sKt_layout = sm100_utils_basic.make_smem_layout_b( + self.tiled_mma_dQ, + self.mma_tiler_dsk, + self.k_dtype, + 1, + ) + self.sKt_layout = cute.slice_(sKt_layout, (None, None, None, 0)) + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.dQ_reduce_ncol, self.sdQaccum_stage) + ) + self.sLSE_layout = cute.make_layout( + shape=(self.tile_m, 
self.Q_stage), + stride=(1, cute.round_up(self.tile_m, 64)), + ) + self.sdPsum_layout = cute.make_layout( + shape=(self.tile_m, self.dO_stage), + stride=(1, cute.round_up(self.tile_m, 64)), + ) + self.sdKV_epi_tile = ( + self.tile_n, + min(128 // (self.dk_dtype.width // 8), self.tile_hdim // 2), # 64 or 32 + ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] + # headdim_64 gets 1 stage + self.num_epi_stages = max(1, (self.tile_hdim // 2) // self.sdKV_epi_tile[1]) + self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages + # TODO: dK and dV could have different shapes + if const_expr(not self.dKV_postprocess): + self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi( + self.dk_dtype, + LayoutEnum.ROW_MAJOR, + self.sdKV_epi_tile, + 2, # num compute wgs + ) + else: + self.sdKV_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2)) + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdQaccum: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + softmax_scale: Float32, + stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + mdQ_semaphore: Optional[cute.Tensor] = None, + mdK_semaphore: Optional[cute.Tensor] = None, + mdV_semaphore: Optional[cute.Tensor] = None, + aux_tensors: Optional[list] = None, + # Block-sparse tensors (Q direction - for iterating m_blocks per n_block): + blocksparse_tensors: Optional[BlockSparseTensors] = None, + ): + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.do_dtype = mdO.element_type + self.lse_dtype = mLSE.element_type + self.dpsum_dtype = mdPsum.element_type + self.dqaccum_dtype = mdQaccum.element_type + self.dk_dtype = mdK.element_type + self.dv_dtype = mdV.element_type + self.ds_dtype = self.q_dtype + + self.is_varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None + self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None + self.use_tma_store = not (self.qhead_per_kvhead == 1 and mCuSeqlensK is not None) + self.dKV_postprocess = self.qhead_per_kvhead > 1 + + if const_expr(self.dKV_postprocess): + assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA" + assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA" + + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + ( + mdQaccum, + mdK, + mdV, + ) = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in ( + mdQaccum, + mdK, + mdV, + ) + ] + + # (b, s, n, h) --> (s, h, n, b) or (t, n, h) -> (t, h, n) + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ, mdO = [utils.select(t, mode=QO_layout_transpose) for t in (mQ, mdO)] + + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mK, mV = [utils.select(t, mode=KV_layout_transpose) for t in (mK, mV)] + + # (b, n, s) --> (s, n, b) or (n, t) --> (t, n) + LSE_dPsum_dQaccum_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 
0] + mLSE, mdPsum, mdQaccum = [ + utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) + ] + + if const_expr(not self.dKV_postprocess): + layout_dKV_transpose = KV_layout_transpose + else: + layout_dKV_transpose = [2, 1, 0] if const_expr(mCuSeqlensK is None) else [1, 0] + mdK, mdV = [utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)] + # (s, h, n, b) --> (h, s, n, b) or (t, h, n) -> (h, t, b) + dO_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensQ is None) else [1, 0, 2] + mdO = utils.select(mdO, mode=dO_transpose) + + # (b, n, block, stage) -> (block, stage, n, b) + semaphore_transpose = [2, 3, 1, 0] + if const_expr(self.deterministic): + assert mdQ_semaphore is not None + mdQ_semaphore = utils.select(mdQ_semaphore, mode=semaphore_transpose) + + if const_expr(self.deterministic and self.qhead_per_kvhead > 1): + assert mdK_semaphore is not None + assert mdV_semaphore is not None + mdK_semaphore, mdV_semaphore = [ + utils.select(t, mode=semaphore_transpose) for t in (mdK_semaphore, mdV_semaphore) + ] + else: + mdK_semaphore = None + mdV_semaphore = None + + self._setup_attributes() + ( + self.tiled_mma_S, + self.tiled_mma_dP, + self.tiled_mma_dK, + self.tiled_mma_dV, + self.tiled_mma_dQ, + ) = self._get_tiled_mma() + self._setup_smem_layout() + + cta_group = tcgen05.CtaGroup.ONE + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (self.tiled_mma_S.thr_id.shape,), + ) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_q_do_mcast = self.num_mcast_ctas_b > 1 + + if const_expr(not self.dKV_postprocess): + self.mdK_layout_enum = LayoutEnum.from_tensor(mdK) + self.mdV_layout_enum = LayoutEnum.from_tensor(mdV) + dK_major_mode = self.mdK_layout_enum.mma_major_mode() + dV_major_mode = self.mdV_layout_enum.mma_major_mode() + if const_expr(dK_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mdK is wrong") + if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mdV is wrong") + + if const_expr(self.use_tma_store and not self.dKV_postprocess): + tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp() + tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom( + tma_copy_op_dKV, + mdK, + cute.select(self.sdKV_layout, mode=[0, 1]), + self.sdKV_epi_tile, + 1, # no mcast + ) + tma_atom_dV, mdV_tma_tensor = cpasync.make_tiled_tma_atom( + tma_copy_op_dKV, + mdV, + cute.select(self.sdKV_layout, mode=[0, 1]), + self.sdKV_epi_tile, + 1, # no mcast + ) + else: + mdV_tma_tensor = mdV + mdK_tma_tensor = mdK + tma_atom_dV = None + tma_atom_dK = None + + if const_expr(not self.dKV_postprocess): + thr_layout_r2s_dKV = cute.make_ordered_layout((128, 1), order=(1, 0)) # 128 threads + val_layout_r2s_dKV = cute.make_ordered_layout( + (1, 128 // self.dk_dtype.width), order=(1, 0) + ) # 4 or 8 vals for 16 byte store + copy_atom_r2s_dKV = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dk_dtype, + num_bits_per_copy=128, + ) + tiled_copy_r2s_dKV = cute.make_tiled_copy_tv( + copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV + ) + else: + tiled_copy_r2s_dKV = copy_utils.tiled_copy_1d( + Float32, 128, num_copy_elems=128 // Float32.width + ) + + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_load_op_multicast = cpasync.CopyBulkTensorTileG2SMulticastOp(cta_group) + + # S.T = K @ Q.T + tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A( + 
tma_load_op, + mK, + cute.select(self.sK_layout, mode=[0, 1, 2]), + self.mma_tiler_kq, + self.tiled_mma_S, + self.cluster_layout_vmnk.shape, + ) + Q_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, self.tiled_mma_S.thr_id + ) + tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( + # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast, + Q_tma_op, + mQ, + cute.select(self.sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_kq, + self.tiled_mma_S, + self.cluster_layout_vmnk.shape, + ) + # dP.T = V @ dO.T + tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mV, + cute.select(self.sV_layout, mode=[0, 1, 2]), + self.mma_tiler_vdo, + self.tiled_mma_dP, + self.cluster_layout_vmnk.shape, + ) + dO_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, self.tiled_mma_dV.thr_id + ) + tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( + # tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast, + dO_tma_op, + mdO, + cute.select(self.sdO_layout, mode=[0, 1, 2]), + self.mma_tiler_pdo, + self.tiled_mma_dV, + self.cluster_layout_vmnk.shape, + ) + + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) + for name, mX, layout in [ + ("Q", mQ, self.sQ_layout), + ("K", mK, self.sK_layout), + ("V", mV, self.sV_layout), + ("dO", mdO, self.sdO_layout), + ] + } + self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 + self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 + self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8 + self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8 + + # TileScheduler = SingleTileScheduler + if const_expr(self.is_varlen_k): + TileScheduler = SingleTileVarlenScheduler + elif const_expr(self.deterministic): + TileScheduler = SingleTileLPTBwdScheduler + else: + TileScheduler = SingleTileScheduler + # reads n_blocks right-to-left + self.spt = (self.is_causal or self.is_local) and self.deterministic + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks + cute.size(mQ.shape[2]), # num_heads = num_query_heads + cute.size(mK.shape[3]) + if const_expr(mCuSeqlensK is None) + else cute.size(mCuSeqlensK.shape[0] - 1), # num_batches + 1, # num_splits + cute.size(mQ.shape[0]), # pass seqlen_q or total_q for seqlen_k + mQ.shape[1], # headdim + mV.shape[1], # headdim_v + total_q=cute.size(mK.shape[0]) # pass total_k for total_q + if const_expr(mCuSeqlensK is not None) + else cute.size(mK.shape[0]) * cute.size(mK.shape[3]), + tile_shape_mn=self.cta_tiler[:2], # (tile_n, tile_m) + cluster_shape_mn=self.cluster_shape_mnk[:2], + mCuSeqlensQ=mCuSeqlensK, + mSeqUsedQ=mSeqUsedK, + qhead_per_kvhead_packgqa=1, # pack_gqa disabled for bwd + element_size=self.k_dtype.width // 8, + is_persistent=self.is_persistent, # persistent mode not tested + lpt=self.spt, + head_swizzle=self.deterministic, + ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + self.tile_scheduler_cls = TileScheduler + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + # cute.printf("grid_dim = {}", grid_dim) + + # Compute allocation sizes for shared buffers that are reused + # sQ is reused for sdK, sdO is reused for sdV + sQ_alloc_bytes = max( + cute.size_in_bytes(self.q_dtype, self.sQ_layout), + cute.size_in_bytes(self.dk_dtype, self.sdKV_layout), 
+ ) + sdO_alloc_bytes = max( + cute.size_in_bytes(self.dv_dtype, self.sdKV_layout), + cute.size_in_bytes(self.do_dtype, self.sdO_layout), + ) + # Sanity check that layouts fit in allocation + sdV_bytes = cute.size_in_bytes(self.dv_dtype, self.sdKV_layout) + sdK_bytes = cute.size_in_bytes(self.dk_dtype, self.sdKV_layout) + assert sdV_bytes <= sdO_alloc_bytes, "sdV doesn't fit in sdO storage allocation" + assert sdK_bytes <= sQ_alloc_bytes, "sdK doesn't fit in sQ storage allocation" + + @cute.struct + class SharedStorage: + Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] + dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] + LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] + dPsum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] + S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] + dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] + dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1] + dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 2] + dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] + dQ_cluster_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.dQaccum_reduce_stage // 2 + ] + dQ_cluster_empty_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.dQaccum_reduce_stage // 2 + ] + tmem_holding_buf: Int32 + tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] + + # Smem tensors + + # sQ is reused for sdK which in the non-MHA case needs float32 + sQ: cute.struct.Align[ + cute.struct.MemRange[cute.Uint8, sQ_alloc_bytes], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(self.sK_layout)], + self.buffer_align_bytes, + ] + sV: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)], + self.buffer_align_bytes, + ] + # sdO is reused for sdV which in the non-MHA case needs float32 + sdO: cute.struct.Align[ + cute.struct.MemRange[cute.Uint8, sdO_alloc_bytes], + self.buffer_align_bytes, + ] + sdS: cute.struct.Align[ + cute.struct.MemRange[self.ds_dtype, cute.cosize(self.sdSt_layout)], + 128, + ] + sLSE: cute.struct.Align[ + cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)], + 128, + ] + sdPsum: cute.struct.Align[ + cute.struct.MemRange[self.dpsum_dtype, cute.cosize(self.sdPsum_layout)], + 128, + ] + sdQaccum: cute.struct.Align[ + cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(self.sdQaccum_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + LOG2_E = math.log2(math.e) + if const_expr(self.score_mod is None): + # Without score_mod: bake scale into log2 + softmax_scale_log2 = softmax_scale * LOG2_E + else: + # With score_mod: score_mod applied to S * softmax_scale, then use LOG2_E only + softmax_scale_log2 = LOG2_E + + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + + fastdiv_mods = None + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) + seqlen_k = cute.size(mK.shape[0]) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + + if const_expr(self.use_block_sparsity or aux_tensors is not None): + assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, 
mSeqUsedQ, mSeqUsedK)), ( + "Variable sequence length is not supported yet for blocksparse or aux tensors in bwd" + ) + + self.kernel( + tma_tensor_Q, + tma_tensor_K, + tma_tensor_V, + mLSE, + mdPsum, + tma_tensor_dO, + mdV, + mdK, + mdQaccum, + mdV_tma_tensor, + mdK_tma_tensor, + mdQ_semaphore, + mdK_semaphore, + mdV_semaphore, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_dO, + tma_atom_dV, + tma_atom_dK, + self.sQ_layout, + self.sQt_layout, + self.sK_layout, + self.sV_layout, + self.sLSE_layout, + self.sdPsum_layout, + self.sdO_layout, + self.sdOt_layout, + self.sdSt_layout, + self.sdS_layout, + self.sKt_layout, + self.sdQaccum_layout, + self.sdKV_layout, + self.tP_layout, + self.tdS_layout, + self.tiled_mma_S, + self.tiled_mma_dP, + self.tiled_mma_dV, + self.tiled_mma_dK, + self.tiled_mma_dQ, + tiled_copy_r2s_dKV, + softmax_scale, + softmax_scale_log2, + window_size_left, + window_size_right, + tile_sched_params, + aux_tensors, + fastdiv_mods, + blocksparse_tensors, + ).launch( + grid=grid_dim, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk if cute.size(self.cluster_shape_mnk) > 1 else None, + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdO: cute.Tensor, + mdV: cute.Tensor, + mdK: cute.Tensor, + mdQaccum: cute.Tensor, + mdV_tma_tensor: Optional[cute.Tensor], + mdK_tma_tensor: Optional[cute.Tensor], + mdQ_semaphore: Optional[cute.Tensor], + mdK_semaphore: Optional[cute.Tensor], + mdV_semaphore: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + tma_atom_dV: Optional[cute.CopyAtom], + tma_atom_dK: Optional[cute.CopyAtom], + sQ_layout: cute.ComposedLayout, + sQt_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sLSE_layout: cute.Layout, + sdPsum_layout: cute.Layout, + sdO_layout: cute.ComposedLayout, + sdOt_layout: cute.ComposedLayout, + sdSt_layout: cute.ComposedLayout, + sdS_layout: cute.ComposedLayout, + sKt_layout: cute.ComposedLayout, + sdQaccum_layout: cute.Layout, + sdKV_layout: cute.ComposedLayout | cute.Layout, + tP_layout: cute.ComposedLayout, + tdS_layout: cute.ComposedLayout, + tiled_mma_S: cute.TiledMma, + tiled_mma_dP: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dQ: cute.TiledMma, + tiled_copy_r2s_dKV: cute.TiledCopy, + softmax_scale: cutlass.Float32, + softmax_scale_log2: cutlass.Float32, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + tile_sched_params: ParamsBase, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + blocksparse_tensors: Optional[BlockSparseTensors] = None, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # Prefetch tma descriptor + if warp_idx == self.load_warp_id: + with cute.arch.elect_one(): + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_dO) + if const_expr(tma_atom_dV is not None): + cpasync.prefetch_descriptor(tma_atom_dV) + if const_expr(tma_atom_dK is not None): + 
cpasync.prefetch_descriptor(tma_atom_dK) + + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (tiled_mma_S.thr_id.shape,), + ) + + # Alloc + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() + dQ_cluster_full_mbar_ptr = storage.dQ_cluster_full_mbar_ptr.data_ptr() + dQ_cluster_empty_mbar_ptr = storage.dQ_cluster_empty_mbar_ptr.data_ptr() + + if warp_idx == 1: + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, cute.arch.WARP_SIZE * len(self.compute_warp_ids) + ) + if const_expr(self.cluster_reduce_dQ): + if warp_idx == 4: + for i in range(self.dQaccum_reduce_stage // 2): + cute.arch.mbarrier_init(dQ_cluster_full_mbar_ptr + i, 1) + cute.arch.mbarrier_init(dQ_cluster_empty_mbar_ptr + i, 1) + + # UMMA producers and AsyncThread consumers + pipeline_producer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + # Only 1 thread per warp will signal + pipeline_consumer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) + ) + pipeline_S_P = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=1, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread, + barrier_storage=storage.S_mbar_ptr.data_ptr(), + ) + pipeline_dP = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=1, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread, + barrier_storage=storage.dP_mbar_ptr.data_ptr(), + ) + pipeline_dKV = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=2, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread, + barrier_storage=storage.dKV_mbar_ptr.data_ptr(), + ) + pipeline_consumer_group_MMA_AsyncThread_dQ = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, + len(self.reduce_warp_ids), + ) # Compute + pipeline_dQ = cutlass.pipeline.PipelineUmmaAsync.create( + num_stages=1, + producer_group=pipeline_producer_group_MMA_AsyncThread, + consumer_group=pipeline_consumer_group_MMA_AsyncThread_dQ, + barrier_storage=storage.dQ_mbar_ptr.data_ptr(), + ) + + # AsyncThread producers and UMMA consumers + # Only 1 thread per warp will signal + pipeline_PdS_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) + ) # Compute + pipeline_PdS_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) + ) # MMA + pipeline_dS = cutlass.pipeline.PipelineAsyncUmma.create( + num_stages=1, + producer_group=pipeline_PdS_producer_group, + consumer_group=pipeline_PdS_consumer_group, + barrier_storage=storage.dS_mbar_ptr.data_ptr(), + ) + + # TMA producer and UMMA consumers + pipeline_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) + ) + # The arrive count is the number of mcast size + pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) * self.num_mcast_ctas_b + ) + pipeline_consumer_group_compute = cutlass.pipeline.CooperativeGroup( + # cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) * self.num_mcast_ctas_b + cutlass.pipeline.Agent.Thread, + len(self.compute_warp_ids) * 1, + ) + pipeline_LSE = 
cutlass.pipeline.PipelineTmaAsync.create( + barrier_storage=storage.LSE_mbar_ptr.data_ptr(), + num_stages=self.Q_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group_compute, + tx_count=self.tma_copy_bytes["LSE"], + # cta_layout_vmnk=cluster_layout_vmnk, + # init_wait=False, + ) + pipeline_dPsum = cutlass.pipeline.PipelineTmaAsync.create( + barrier_storage=storage.dPsum_mbar_ptr.data_ptr(), + num_stages=self.dO_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group_compute, + tx_count=self.tma_copy_bytes["dPsum"], + # cta_layout_vmnk=cluster_layout_vmnk, + # init_wait=False, + ) + pipeline_Q = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.Q_mbar_ptr.data_ptr(), + num_stages=self.Q_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_bytes["Q"], + cta_layout_vmnk=cluster_layout_vmnk, + init_wait=False, + ) + pipeline_dO = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.dO_mbar_ptr.data_ptr(), + num_stages=self.dO_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_bytes["dO"], + cta_layout_vmnk=cluster_layout_vmnk, + init_wait=True, + ) + + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner, dtype=self.q_dtype) + sQt = cute.make_tensor( + cute.recast_ptr(sQ.iterator, sQt_layout.inner, dtype=self.q_dtype), sQt_layout.outer + ) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sKt = cute.make_tensor(cute.recast_ptr(sK.iterator, sKt_layout.inner), sKt_layout.outer) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner) + sdS = cute.make_tensor(cute.recast_ptr(sdSt.iterator, sdS_layout.inner), sdS_layout.outer) + sdO = storage.sdO.get_tensor( + sdO_layout.outer, swizzle=sdO_layout.inner, dtype=self.do_dtype + ) + sdOt = cute.make_tensor( + cute.recast_ptr(sdO.iterator, sdOt_layout.inner, dtype=self.do_dtype), sdOt_layout.outer + ) + sLSE = storage.sLSE.get_tensor(sLSE_layout) + sdPsum = storage.sdPsum.get_tensor(sdPsum_layout) + if const_expr(not self.dKV_postprocess): + sdV = storage.sdO.get_tensor( + sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype + ) + sdK = storage.sQ.get_tensor( + sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype + ) + else: + sdV = storage.sdO.get_tensor(sdKV_layout, dtype=self.dv_dtype) + sdK = storage.sQ.get_tensor(sdKV_layout, dtype=self.dk_dtype) + + # Buffer sizing is guaranteed by max(...) in SharedStorage declarations + # for both sQ (reused as sdK) and sdO (reused as sdV) + + sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) + + # TMEM + # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always + # request 512 columns of tmem, so we know that it starts at 0. 
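+        # Column offsets within that 512-column allocation (set in __init__):
+        #     tmem_S_offset  = 0 (P overlaps S)
+        #     tmem_dV_offset = tmem_S_offset + tile_n
+        #     tmem_dP_offset = tmem_dV_offset + tile_hdimv (dQ and dS overlap dP)
+        #     tmem_dK_offset = tmem_dP_offset + tile_m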
+ tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16) + # S + thr_mma_S = tiled_mma_S.get_slice(0) + Sacc_shape = thr_mma_S.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N) + tStS = thr_mma_S.make_fragment_C(Sacc_shape) + # (MMA, MMA_M, MMA_N) + tStS = cute.make_tensor(tmem_ptr + self.tmem_S_offset, tStS.layout) + # dP + thr_mma_dP = tiled_mma_dP.get_slice(0) + dPacc_shape = thr_mma_dP.partition_shape_C(self.mma_tiler_vdo[:2]) + tdPtdP = thr_mma_dP.make_fragment_C(dPacc_shape) + tdPtdP = cute.make_tensor(tmem_ptr + self.tmem_dP_offset, tdPtdP.layout) + # dV + thr_mma_dV = tiled_mma_dV.get_slice(0) + dvacc_shape = thr_mma_dV.partition_shape_C(self.mma_tiler_pdo[:2]) + tdVtdV = thr_mma_dV.make_fragment_C(dvacc_shape) + tdVtdV = cute.make_tensor(tmem_ptr + self.tmem_dV_offset, tdVtdV.layout) + tP = cute.make_tensor( + cute.recast_ptr(tmem_ptr + self.tmem_P_offset, dtype=self.do_dtype), tP_layout.outer + ) + # dK + thr_mma_dK = tiled_mma_dK.get_slice(0) + dkacc_shape = thr_mma_dK.partition_shape_C(self.mma_tiler_dsq[:2]) + tdKtdK = thr_mma_dK.make_fragment_C(dkacc_shape) + tdKtdK = cute.make_tensor(tmem_ptr + self.tmem_dK_offset, tdKtdK.layout) + tdS = cute.make_tensor( + cute.recast_ptr(tmem_ptr + self.tmem_dS_offset, dtype=self.ds_dtype), tdS_layout.outer + ) + # dQ + thr_mma_dQ = tiled_mma_dQ.get_slice(0) + dQacc_shape = thr_mma_dQ.partition_shape_C(self.mma_tiler_dsk[:2]) + tdQtdQ = thr_mma_dQ.make_fragment_C(dQacc_shape) + tdQtdQ = cute.make_tensor(tmem_ptr + self.tmem_dQ_offset, tdQtdQ.layout) + + block_info = BlockInfo( + self.tile_m, + # self.tile_n, + self.tile_n * self.cluster_shape_mnk[0], # careful, this case is not very well-tested + self.is_causal, + self.is_local, + False, # is_split_kv + window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + tile_m=self.tile_m, + tile_n=self.tile_n, + ) + TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) + + AttentionMaskCls = partial( + AttentionMask, + self.tile_m, + self.tile_n, + swap_AB=True, + window_size_left=window_size_left, + window_size_right=window_size_right, + ) + + # EMPTY + # (15) + if warp_idx == self.empty_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # EPI + # (14) + if warp_idx == self.epi_warp_id: + # currently no-op, could use for tma store/reduce + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # LOAD + # (13) + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.load( + thr_mma_S, + thr_mma_dP, + thr_mma_dV, + mQ, + mK, + mV, + mLSE, + mdPsum, + mdO, + sQ, + sK, + sV, + sLSE, + sdPsum, + sdO, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_dO, + pipeline_Q, + pipeline_dO, + pipeline_LSE, + pipeline_dPsum, + cluster_layout_vmnk, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + blocksparse_tensors, + should_load_Q=True, + should_load_dO=True, + ) + + # MMA + # (12) + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + # Alloc tmem buffer + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.sync_warp() + + self.mma( + tiled_mma_S, + tiled_mma_dP, + tiled_mma_dV, + tiled_mma_dK, + tiled_mma_dQ, + sQ, + sQt, + sK, + sV, + sdO, + sdOt, + 
sdSt, + sdS, + sKt, + tP, + tdS, + tStS, + tdPtdP, + tdVtdV, + tdKtdK, + tdQtdQ, + pipeline_Q.make_consumer(), + pipeline_dO, + pipeline_S_P, + pipeline_dS, + pipeline_dKV, + pipeline_dP, + pipeline_dQ, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + blocksparse_tensors, + ) + cute.arch.relinquish_tmem_alloc_permit() + tmem_ptr = cute.arch.retrieve_tmem_ptr( + Float32, alignment=16, ptr_to_buffer_holding_addr=storage.tmem_holding_buf + ) + + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols, is_two_cta=False) + + # Compute + # (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps + if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_compute) # 8 warps + self.compute_loop( + thr_mma_S, + thr_mma_dP, + thr_mma_dV, + thr_mma_dK, + tStS, + sLSE, + sdPsum, + tdVtdV, + tdKtdK, + mdV, + mdK, + sdS, + tdPtdP, + pipeline_LSE, + pipeline_dPsum, + pipeline_S_P, + pipeline_dS, + pipeline_dKV, + pipeline_dP, + softmax_scale, + softmax_scale_log2, + block_info, + SeqlenInfoCls, + AttentionMaskCls, + TileSchedulerCls, + sdV, + sdK, + mdV_tma_tensor, + mdK_tma_tensor, + tma_atom_dV, + tma_atom_dK, + tiled_copy_r2s_dKV, + mdK_semaphore, + mdV_semaphore, + aux_tensors, + fastdiv_mods, + blocksparse_tensors, + ) + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) + + # Reduce + # (0, 1, 2, 3) - dQ + if warp_idx >= self.reduce_warp_ids[0] and warp_idx <= self.reduce_warp_ids[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_reduce) + self.dQacc_reduce( + mdQaccum, + sdQaccum, + thr_mma_dQ, + tdQtdQ, + pipeline_dQ, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + mdQ_semaphore, + blocksparse_tensors, + ) + + return + + @cute.jit + def load( + self, + thr_mma_S: cute.core.ThrMma, + thr_mma_dP: cute.core.ThrMma, + thr_mma_dV: cute.core.ThrMma, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdO: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sLSE: cute.Tensor, + sdPsum: cute.Tensor, + sdO: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + pipeline_Q: PipelineAsync, + pipeline_dO: PipelineAsync, + pipeline_LSE: PipelineAsync, + pipeline_dPsum: PipelineAsync, + cluster_layout_vmnk: cute.Layout, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + should_load_Q: bool = True, + should_load_dO: bool = True, + ): + producer_state_Q_LSE = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.Q_stage + ) + producer_state_dO_dPsum = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.dO_stage + ) + + # Compute multicast mask for Q & dO buffer full + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + q_do_mcast_mask = None + if const_expr(self.is_q_do_mcast): + q_do_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + m_block_min, 
m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) + head_idx_kv = head_idx // self.qhead_per_kvhead + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] + if const_expr(not seqlen.has_cu_seqlens_q): + mdO_cur = mdO[None, None, head_idx, batch_idx] + else: + mdO_cur = cute.domain_offset((0, seqlen.offset_q), mdO[None, None, head_idx]) + mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[None, head_idx] + mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[ + None, head_idx + ] + + gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0)) + tSgK = thr_mma_S.partition_A(gK) + gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_vdo, mode=[0, 2]), (n_block, 0)) + tdPgV = thr_mma_dP.partition_A(gV) + gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0)) + tSgQ = thr_mma_S.partition_B(gQ) + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,)) + gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,)) + gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None)) + tdPgdO = thr_mma_dV.partition_B(gdO) + + load_K, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), tSgK, sK, single_stage=True + ) + load_V, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, + 0, + cute.make_layout(1), + tdPgV, + sV, + single_stage=True, + ) + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + load_Q, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, + cta_coord=block_in_cluster_coord_vmnk[1], + cta_layout=b_cta_layout, + src_tensor=tSgQ, + dst_tensor=sQ, + mcast_mask=q_do_mcast_mask, + ) + load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q) + load_dO, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dO, + cta_coord=block_in_cluster_coord_vmnk[1], + cta_layout=b_cta_layout, + src_tensor=tdPgdO, + dst_tensor=sdO, + mcast_mask=q_do_mcast_mask, + ) + load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO) + copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32) + copy_stats = partial(cute.copy, copy_atom_stats) + # copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SMulticastOp(), Float32) + # sLSE = cute.logical_divide(sLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] + # gLSE = cute.logical_divide(gLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] + # sdPsum = cute.logical_divide(sdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] + # gdPsum = cute.logical_divide(gdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] + # copy_stats = partial(cute.copy, copy_atom_stats, mcast_mask=q_do_mcast_mask) + + # some tiles might be empty due to block sparsity + if const_expr(self.use_block_sparsity): + total_m_block_cnt = get_total_q_block_count_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + process_tile = total_m_block_cnt > Int32(0) + else: + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) + + if process_tile: + if const_expr(self.use_block_sparsity): + producer_state_Q_LSE, producer_state_dO_dPsum = ( + produce_block_sparse_q_loads_bwd_sm100( + blocksparse_tensors, + batch_idx, + 
head_idx, + n_block, + producer_state_Q_LSE, + producer_state_dO_dPsum, + pipeline_Q, + pipeline_LSE, + pipeline_dO, + pipeline_dPsum, + load_K, + load_V, + load_Q, + load_dO, + copy_stats, + gLSE, + sLSE, + gdPsum, + sdPsum, + self.tma_copy_bytes["K"], + self.tma_copy_bytes["V"], + should_load_Q=should_load_Q, + should_load_dO=should_load_dO, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + ) + else: + first_m_block = m_block_min + + # First iteration: load K together w Q & LSE, then V together w dO & dPsum + if const_expr(should_load_Q): + pipeline_Q.producer_acquire( + producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) + load_Q(first_m_block, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) + pipeline_LSE.producer_acquire(producer_state_Q_LSE) + with cute.arch.elect_one(): + copy_stats( + gLSE[None, first_m_block], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), + ) + producer_state_Q_LSE.advance() + if const_expr(should_load_dO): + pipeline_dO.producer_acquire( + producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"] + ) + load_V( + tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum) + ) + load_dO(first_m_block, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) + with cute.arch.elect_one(): + copy_stats( + gdPsum[None, first_m_block], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier( + producer_state_dO_dPsum + ), + ) + producer_state_dO_dPsum.advance() + + # Dense path: iterate from m_block_min+1 to m_block_max + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + if const_expr(should_load_Q): + pipeline_Q.producer_acquire(producer_state_Q_LSE) + load_Q(m_block, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) + pipeline_LSE.producer_acquire(producer_state_Q_LSE) + with cute.arch.elect_one(): + copy_stats( + gLSE[None, m_block], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier( + producer_state_Q_LSE + ), + ) + producer_state_Q_LSE.advance() + if const_expr(should_load_dO): + pipeline_dO.producer_acquire(producer_state_dO_dPsum) + load_dO(m_block, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) + with cute.arch.elect_one(): + copy_stats( + gdPsum[None, m_block], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier( + producer_state_dO_dPsum + ), + ) + producer_state_dO_dPsum.advance() + + if const_expr(should_load_Q): + pipeline_Q.producer_tail( + producer_state_Q_LSE.clone() + ) # will hang if we don't clone + pipeline_LSE.producer_tail(producer_state_Q_LSE) + if const_expr(should_load_dO): + pipeline_dO.producer_tail(producer_state_dO_dPsum.clone()) + pipeline_dPsum.producer_tail(producer_state_dO_dPsum) + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def mma( + self, + tiled_mma_S: cute.TiledMma, + tiled_mma_dP: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dQ: cute.TiledMma, + sQ: cute.Tensor, + sQt: cute.Tensor, + sK: 
cute.Tensor, + sV: cute.Tensor, + sdO: cute.Tensor, + sdOt: cute.Tensor, + sdSt: cute.Tensor, + sdS: cute.Tensor, + sKt: cute.Tensor, + tP: cute.Tensor, + tdS: cute.Tensor, + tStS: cute.Tensor, + tdPtdP: cute.Tensor, + tdVtdV: cute.Tensor, + tdKtdK: cute.Tensor, + tdQtdQ: cute.Tensor, + pipeline_Q_consumer: PipelineConsumer, + pipeline_dO: PipelineAsync, + pipeline_S_P: PipelineAsync, + pipeline_dS: PipelineAsync, + pipeline_dKV: PipelineAsync, + pipeline_dP: PipelineAsync, + pipeline_dQ: PipelineAsync, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + ): + # [2025-10-21] For reasons I don't understand, putting these partitioning in the main + # kernel (before warp specialization) is a lot slower tha putting them here. + # Partition smem / tmem tensors + # S = K @ Q.T + tSrK = tiled_mma_S.make_fragment_A(sK) + tSrQ = tiled_mma_S.make_fragment_B(sQ) + # dP = V @ dO.T + tdPrV = tiled_mma_dP.make_fragment_A(sV) + tdPrdOt = tiled_mma_dP.make_fragment_B(sdOt) + # dK = dS.T @ Q + if const_expr(self.use_smem_dS_for_mma_dK): + tdKrdS = tiled_mma_dK.make_fragment_A(sdSt) + else: + tdKrdS = tiled_mma_dK.make_fragment_A(tdS) + tdKrQ = tiled_mma_dK.make_fragment_B(sQt) + # dQ = dS @ K + tdQrdS = tiled_mma_dQ.make_fragment_A(sdS) + tdQrK = tiled_mma_dQ.make_fragment_B(sKt) + # dV = P @ dO.T + tdVrdO = tiled_mma_dV.make_fragment_B(sdO) + tdVrP = tiled_mma_dV.make_fragment_A(tP) + + # mma_qk_fn = partial(gemm_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, zero_init=True) + mma_qk_fn = partial( + gemm_ptx_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, sA=sK, sB=sQ, zero_init=True + ) + # mma_dov_fn = partial(gemm_w_idx, tiled_mma_dP, tdPtdP, tdPrV, tdPrdOt, zero_init=True) + mma_dov_fn = partial( + gemm_ptx_w_idx, + tiled_mma_dP, + tdPtdP, + tdPrV, + tdPrdOt, + sA=sV, + sB=sdOt, + zero_init=True, + ) + # mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO) + mma_pdo_fn = partial( + gemm_ptx_w_idx, + tiled_mma_dV, + tdVtdV, + tdVrP, + tdVrdO, + sA=None, + sB=sdO, + tA_addr=self.tmem_P_offset, + ) + mma_dsk_fn = partial(gemm_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, zero_init=True) + # mma_dsk_fn = partial( + # gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, zero_init=True + # ) + if const_expr(self.use_smem_dS_for_mma_dK): + mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ) + else: + # Need to explicitly pass in tA_addr for correctness + mma_dsq_fn = partial( + gemm_ptx_w_idx, + tiled_mma_dK, + tdKtdK, + tdKrdS, + tdKrQ, + sA=None, + sB=sQt, + tA_addr=self.tmem_dS_offset, + ) + + consumer_state_dO = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage + ) + producer_phase_acc = Int32(1) # For S & P, dP, dQ + consumer_state_dS = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, 1 + ) + # producer_state_dKV = cutlass.pipeline.make_pipeline_state( + # cutlass.pipeline.PipelineUserType.Producer, 2 + # ) + producer_phase_dKV = Int32(1) + cta_group = pipeline_S_P.cta_group + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k + m_block_min, m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) + + if const_expr(self.use_block_sparsity): + block_iter_count = 
get_total_q_block_count_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + process_tile = block_iter_count > Int32(0) + else: + block_iter_count = m_block_max - m_block_min + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) + + if process_tile: + accumulate_dK = False + # ----------------------------------------------------------- + ###### Prologue + # ----------------------------------------------------------- + # 1. S = Q0 @ K.T + # 2. dP = V @ dO.T + # 3. dV = P @ dO + # 1) S = Q0 @ K.T + handle_Q = pipeline_Q_consumer.wait_and_advance() + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_qk_fn(B_idx=handle_Q.index) + # Don't release Q yet + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + + # 2) dP = V @ dO.T + pipeline_dO.consumer_wait(consumer_state_dO) + pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) + # dQ uses the same tmem as dP + pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) + mma_dov_fn(B_idx=consumer_state_dO.index) + # Don't release dO yet + pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) + + producer_phase_acc ^= 1 + # 3) dV = P.T @ dO + # wait for P to be ready, which uses the same tmem as S + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True) + pipeline_dO.consumer_release(consumer_state_dO) + consumer_state_dO.advance() + # ----------------------------------------------------------- + ###### MAIN LOOP + # ----------------------------------------------------------- + # 1. S = K @ Q.T + # 2. dQ = dS @ K + # 3. dK = dS.T @ Q + # 4. dP = V @ dO.T + # 5. 
dV = P.T @ dO + + # For block sparsity, we use block_iter_count; for dense, use m_block range + # MMA doesn't need actual m_block indices, just the iteration count + main_loop_iters = ( + block_iter_count - 1 + if const_expr(self.use_block_sparsity) + else m_block_max - m_block_min - 1 + ) + for _ in cutlass.range(main_loop_iters, unroll=1): + # 1) S = K @ Q_i + handle_Q_next = pipeline_Q_consumer.wait_and_advance() + # Don't need to wait for S, as P must have been ready ealier, i.e., S is ready + mma_qk_fn(B_idx=handle_Q_next.index) + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + + # 2-3) + # Do dK = dS.T @ Q, then dQ = dS @ K if dS in tmem for first mma + # Otherwise, reverse order + pipeline_dS.consumer_wait(consumer_state_dS) + + if const_expr(self.use_smem_dS_for_mma_dK): + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + accumulate_dK = True + handle_Q.release() + else: + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + accumulate_dK = True + handle_Q.release() + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + + # dP uses the same tmem as dQ + # However, if dS is ready, then dP must have been ready, + # so we don't need this wait before mma_dsk_fn() + # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc) + + pipeline_dS.consumer_release(consumer_state_dS) + consumer_state_dS.advance() + + # 4) dP = V @ dO.T + pipeline_dO.consumer_wait(consumer_state_dO) + # dQ uses the same tmem as dP + pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc) + mma_dov_fn(B_idx=consumer_state_dO.index) + pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group) + + producer_phase_acc ^= 1 + # 5) dV += P @ dO + # wait for P to be ready, which uses the same tmem as S + pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc) + mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False) + pipeline_dO.consumer_release(consumer_state_dO) + consumer_state_dO.advance() + + handle_Q = handle_Q_next + + pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group) + + # signal to the epilogue that dV is ready + # pipeline_dKV.producer_acquire(producer_state_dKV) + pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV) + # pipeline_dKV.producer_commit(producer_state_dKV) + pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group) + # producer_state_dKV.advance() + # pipeline_dKV.producer_acquire(producer_state_dKV) + pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV) + + # ----------------------------------------------------------- + ###### Remaining 2 + # ----------------------------------------------------------- + # 1) dK += dS.T @ Q + pipeline_dS.consumer_wait(consumer_state_dS) + mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK) + # signal to the epilogue that dK is ready + # pipeline_dKV.producer_commit(producer_state_dKV) + pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group) + # producer_state_dKV.advance() + producer_phase_dKV ^= 1 + + # 2) dQ = dS @ K + # dS is done, so dP must have been ready, we don't need to wait + mma_dsk_fn() + pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group) + # Wait until dQ is done before releasing Q, since K and Q0 uses the same mbarrier + handle_Q.release() + pipeline_dS.consumer_release(consumer_state_dS) + 
consumer_state_dS.advance() + + producer_phase_acc ^= 1 + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + # Currently it hangs if we have this S_P.producer_tail, will need to understand why + # pipeline_S_P.producer_tail(producer_state_S_P) + # pipeline_dP.producer_tail(producer_state_dP) + # pipeline_dKV.producer_tail(producer_state_dKV) + # pipeline_dQ.producer_tail(producer_state_dQ) + + @cute.jit + def split_wg( + self, + t: cute.Tensor, + wg_idx: cutlass.Int32, + num_wg: cutlass.Constexpr[int], + ): + reduced_shape = cute.product_each(t.shape) + rank = len(reduced_shape) + if const_expr(reduced_shape[1] > 1): + assert rank >= 2, "Need rank >= 2 for t in split_wg" + t = cute.logical_divide(t, (reduced_shape[0], reduced_shape[1] // num_wg)) + coord = (None, (None, wg_idx)) + (None,) * (rank - 2) + else: + assert rank >= 3, "Need rank >= 3 for t in split_wg" + if const_expr(rank == 3): + t = cute.logical_divide( + t, (reduced_shape[0], reduced_shape[1], reduced_shape[2] // num_wg) + ) + coord = ( + None, + None, + (None, wg_idx), + ) + (None,) * (rank - 3) + else: + t = cute.logical_divide( + t, + ( + reduced_shape[0], + reduced_shape[1], + reduced_shape[2], + reduced_shape[3] // num_wg, + ), + ) + coord = ( + None, + None, + None, + (None, wg_idx), + ) + (None,) * (rank - 4) + return t[coord] + + @cute.jit + def apply_score_mod( + self, + tSrS_t2r, + thr_copy_t2r, + thr_mma_S, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen_info, + aux_tensors=None, + fastdiv_mods=(None, None), + ): + """Apply forward score modification for SM100 backward pass.""" + # In bwd, S is computed as K @ Q.T so dimensions are (tile_n, tile_m) + cS = cute.make_identity_tensor((self.tile_n, self.tile_m)) + cS = cute.domain_offset((n_block * self.tile_n, m_block * self.tile_m), cS) + tScS = thr_mma_S.partition_C(cS) + tScS_idx = thr_copy_t2r.partition_D(tScS) + + apply_score_mod_inner( + tSrS_t2r, + tScS_idx, + self.score_mod, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + transpose_indices=True, + ) + + @cute.jit + def apply_score_mod_bwd( + self, + grad_tensor, + score_tensor, + index_tensor, + batch_idx, + head_idx, + softmax_scale, + seqlen_info, + aux_tensors=None, + fastdiv_mods=(None, None), + ): + """Apply backward score modification (joint graph) for SM100.""" + apply_score_mod_bwd_inner( + grad_tensor, + score_tensor, + index_tensor, + self.score_mod_bwd, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + transpose_indices=True, + ) + + @cute.jit + def compute_loop( + self, + thr_mma_S: cute.core.ThrMma, + thr_mma_dP: cute.core.ThrMma, + thr_mma_dV: cute.core.ThrMma, + thr_mma_dK: cute.core.ThrMma, + tStS: cute.Tensor, + sLSE: cute.Tensor, + sdPsum: cute.Tensor, + tdVtdV: cute.Tensor, + tdKtdK: cute.Tensor, + mdV: cute.Tensor, + mdK: cute.Tensor, + sdS: cute.Tensor, + tdPtdP: cute.Tensor, + pipeline_LSE: PipelineAsync, + pipeline_dPsum: PipelineAsync, + pipeline_S_P: PipelineAsync, + pipeline_dS: PipelineAsync, + pipeline_dKV: PipelineAsync, + pipeline_dP: PipelineAsync, + softmax_scale: cutlass.Float32, + softmax_scale_log2: cutlass.Float32, + block_info: BlockInfo, + SeqlenInfoCls: 
Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + sdV: Optional[cute.Tensor], + sdK: Optional[cute.Tensor], + mdV_tma_tensor: Optional[cute.Tensor], + mdK_tma_tensor: Optional[cute.Tensor], + tma_atom_dV: Optional[cute.CopyAtom], + tma_atom_dK: Optional[cute.CopyAtom], + tiled_copy_r2s_dKV: Optional[cute.TiledCopy], + mdK_semaphore: Optional[cute.Tensor], + mdV_semaphore: Optional[cute.Tensor], + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + blocksparse_tensors: Optional[BlockSparseTensors] = None, + ): + sLSE_2D = cute.make_tensor( + sLSE.iterator, + cute.make_layout( + (self.tile_m, self.tile_n, self.Q_stage), + stride=(1, 0, cute.round_up(self.tile_m, 64)), + ), + ) + sdPsum_2D = cute.make_tensor( + sdPsum.iterator, + cute.make_layout( + (self.tile_m, self.tile_n, self.dO_stage), + stride=(1, 0, cute.round_up(self.tile_m, 64)), + ), + ) + # if const_expr(self.SdP_swapAB): + if const_expr(True): + sLSE_2D = utils.transpose_view(sLSE_2D) + sdPsum_2D = utils.transpose_view(sdPsum_2D) + + # tix: [128...384] 8 warps + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # 4-11 + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids)) + # tidx = cute.arch.thread_idx()[0] - (cute.arch.WARP_SIZE * self.compute_warp_ids[0]) + dp_idx = tidx % 128 + num_wg = len(self.compute_warp_ids) // 4 # 2 + # wg_idx: + # 0: [256...384] + # 1: [128...256] + + tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # 64 for tile_n = 128 + # tStS has shape ((128, 128), 1, 1), tStP has shape ((128, 64), 1, 1) + # tP overlap with tS + tStP = cute.composition(tStS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) + tStP = cute.make_tensor(tStS.iterator, tStP.layout) # Otherwise the tmem address is wrong + tScS = thr_mma_S.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2])) + tScP = cute.composition(tScS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) + # tdS overlap with tdP + tdPtdS = cute.composition(tdPtdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) + tdPcdP = thr_mma_dP.partition_C(cute.make_identity_tensor(self.mma_tiler_vdo[:2])) + tdPcdS = cute.composition(tdPcdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1)) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 + ) + + # tmem -> rmem + thr_copy_t2r = copy_utils.make_tmem_copy(tmem_load_atom, num_wg).get_slice(tidx) + tStS_t2r = thr_copy_t2r.partition_S(tStS) # (((32, 32), 1), 2, 1, 1) + tdPtdP_t2r = thr_copy_t2r.partition_S(tdPtdP) + tScS_t2r = thr_copy_t2r.partition_D(tScS) # ((32, 1), 2, 1, 1) + t0ScS_t2r = thr_copy_t2r.get_slice(0).partition_D(tScS) # ((32, 1), 2, 1, 1) + # ((32, 1), 2, 1, 1, STAGE) + tSsLSE = thr_copy_t2r.partition_D(thr_mma_S.partition_C(sLSE_2D)) + tSsdPsum = thr_copy_t2r.partition_D(thr_mma_dP.partition_C(sdPsum_2D)) + # rmem -> tmem + thr_copy_r2t = copy_utils.make_tmem_copy(tmem_store_atom, num_wg).get_slice(tidx) + tScP_r2t = thr_copy_r2t.partition_S(tScP) + tStP_r2t = thr_copy_r2t.partition_D(tStP) + tdPcdS_r2t = thr_copy_r2t.partition_S(tdPcdS) + tdPtdS_r2t = thr_copy_r2t.partition_D(tdPtdS) + # rmem -> smem + # This part is a bit iffy, we might be making a lot of assumptions here + copy_atom_r2s = sm100_utils_basic.get_smem_store_op( + LayoutEnum.ROW_MAJOR, self.ds_dtype, Float32, thr_copy_t2r + ) + thr_copy_r2s = 
cute.make_tiled_copy_D(copy_atom_r2s, thr_copy_t2r).get_slice(tidx) + # We assume the swizzle (i.e. layout.inner) stays the same + sdS_layout = sm100_utils_basic.make_smem_layout_epi( + self.ds_dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_m), 1 + ).outer # ((8,16), (64,2), (1, 1)) + sdS_layout = cute.slice_(sdS_layout, (None, None, 0)) # ((8,16), (64,2)) + # Need to group into 1 mode to be compatible w thr_copy_r2s + sdS_layout = cute.make_layout((sdS_layout.shape,), stride=(sdS_layout.stride,)) + sdS_epi = cute.make_tensor(sdS.iterator, sdS_layout) + tRS_sdS = thr_copy_r2s.partition_D(sdS_epi) + + consumer_state_S_P_dP = pipeline.make_pipeline_state( # Our impl has shortcut for stage==1 + cutlass.pipeline.PipelineUserType.Consumer, 1 + ) + # consumer_phase_S_P_dP = Int32(0) + producer_state_dS = pipeline.make_pipeline_state( # Our impl has shortcut for stage==1 + cutlass.pipeline.PipelineUserType.Producer, 1 + ) + consumer_state_dKV = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, 2 + ) + consumer_state_LSE = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage + ) + # consumer_state_dPsum = cutlass.pipeline.make_pipeline_state( + consumer_state_dPsum = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage + ) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + m_block_min, m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) + mask = AttentionMaskCls(seqlen) + # TODO: condition mask_seqlen + mask_fn = partial( + mask.apply_mask_sm100_transposed, + tScS_t2r=tScS_t2r, + t0ScS_t2r=t0ScS_t2r, + n_block=n_block, + mask_seqlen=True, + mask_causal=self.is_causal, + mask_local=self.is_local, + mask_mod=self.mask_mod, + batch_idx=batch_idx, + head_idx=head_idx, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + + # prefetch_LSE = not self.is_causal + prefetch_LSE = False + + # some tiles might be empty due to block sparsity + if const_expr(self.use_block_sparsity): + ( + curr_q_cnt, + curr_q_idx, + curr_full_cnt, + curr_full_idx, + loop_count, + ) = get_block_sparse_iteration_info_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + process_tile = loop_count > Int32(0) + else: + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) + loop_count = m_block_max - m_block_min + + # Mainloop + # Block sparsity: iterate over sparse m_block count and derive actual m_block + # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly. 
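+ # Rough per-iteration reference (comments only, not executed). With Q_i, dO_i, LSE_i and
+ # dPsum_i for m_block i, the loop below effectively computes, following the mma comments:
+ #   S  = K @ Q_i.T
+ #   P  = exp2(S * softmax_scale_log2 - LSE_i)
+ #   dP = V @ dO_i.T
+ #   dS = P * (dP - dPsum_i)            # dPsum_i is the precomputed per-row term D
+ #   dV += P.T @ dO_i ; dK += dS.T @ Q_i ; dQ_i = dS @ K (handled by the reduce warps)
+ # This is only a sketch for intuition; score_mod/mask_mod, GQA and block sparsity adjust it.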
+ for iter_idx in cutlass.range(loop_count, unroll=1): + if const_expr(self.use_block_sparsity): + m_block, is_full_block = get_m_block_from_iter_bwd( + iter_idx, + curr_q_cnt, + curr_q_idx, + curr_full_cnt, + curr_full_idx, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + m_block_oob = m_block >= m_block_max + else: + m_block = m_block_min + iter_idx + m_block_oob = False + is_full_block = False + # Prefetch 1 stage of LSE + pipeline_LSE.consumer_wait(consumer_state_LSE) + tSrLSE_s2r = cute.make_fragment(tScS_t2r[None, 0, 0, 0].shape, Float32) + if const_expr(prefetch_LSE and not self.shuffle_LSE): + cute.autovec_copy(tSsLSE[None, 0, 0, 0, consumer_state_LSE.index], tSrLSE_s2r) + + pipeline_S_P.consumer_wait(consumer_state_S_P_dP) + # pipeline_S_P.sync_object_full.wait(0, consumer_phase_S_P_dP) + #### TMEM->RMEM (Load S from TMEM) + tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32) + cute.copy(thr_copy_t2r, tStS_t2r, tSrS_t2r) + if const_expr(self.score_mod_bwd is not None): + tSrS_pre = cute.make_fragment_like(tSrS_t2r) + cute.autovec_copy(tSrS_t2r, tSrS_pre) + + if const_expr(self.score_mod is not None): + # Apply score_mod FIRST -> matches forward + self.apply_score_mod( + tSrS_t2r, + thr_copy_t2r, + thr_mma_S, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen, + aux_tensors, + fastdiv_mods, + ) + + #### APPLY MASK (after score_mod, matching forward pass order) + check_m_boundary = (m_block + 1) * self.tile_m > seqlen.seqlen_q + mask_fn( + tSrS_t2r, + m_block=m_block, + is_full_block=is_full_block, + check_m_boundary=check_m_boundary, + ) + + num_stages = cute.size(tScS_t2r, mode=[1]) + + # --------------------------------------------- + #### P = exp(S - LSE) + # --------------------------------------------- + lane_idx = cute.arch.lane_idx() + tSrP_r2t_f32 = cute.make_fragment(tScP_r2t.shape, Float32) # 64 + tSrP_r2t = cute.recast_tensor(tSrP_r2t_f32, self.q_dtype) + for stage in cutlass.range_constexpr(num_stages): + tSrS_cur = tSrS_t2r[None, stage, 0, 0] + tSsLSE_cur = tSsLSE[None, stage, 0, 0, consumer_state_LSE.index] + if const_expr(not self.shuffle_LSE): + if const_expr(stage > 0 or not prefetch_LSE): + cute.autovec_copy(tSsLSE_cur, tSrLSE_s2r) + tSrLSE = tSrLSE_s2r + else: + tSrLSE = tSsLSE_cur[lane_idx] + for v in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[0]) // 2): + if const_expr(not self.shuffle_LSE): + lse_pair = (tSrLSE[2 * v], tSrLSE[2 * v + 1]) + else: + lse_pair = ( + utils.shuffle_sync(tSrLSE, offset=2 * v), + utils.shuffle_sync(tSrLSE, offset=2 * v + 1), + ) + tSrS_cur[2 * v], tSrS_cur[2 * v + 1] = utils.fma_packed_f32x2( + ((tSrS_cur[2 * v], tSrS_cur[2 * v + 1])), + (softmax_scale_log2, softmax_scale_log2), + (-lse_pair[0], -lse_pair[1]), + ) + tSrS_cur[2 * v] = cute.math.exp2(tSrS_cur[2 * v], fastmath=True) + tSrS_cur[2 * v + 1] = cute.math.exp2(tSrS_cur[2 * v + 1], fastmath=True) + utils.cvt_f16(tSrS_cur, tSrP_r2t[None, stage, 0, 0]) + if const_expr(stage == 0): + cute.arch.fence_view_async_tmem_load() + # Without this barrier, we could have 1 warp writing to P in tmem while + # another warp is still reading S from tmem. 
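+ # (fence_view_async_tmem_load above orders this warp's own async tmem loads of S; the named
+ # barrier below then ensures every compute warp has finished reading S before any warp
+ # overwrites that tmem region with P via the r2t copy.)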
+ self.compute_sync_barrier.arrive_and_wait() + cute.copy( + thr_copy_r2t, + tSrP_r2t_f32[None, stage, None, None], + tStP_r2t[None, stage, None, None], + ) + + cute.arch.fence_view_async_tmem_store() + self.compute_sync_barrier.arrive_and_wait() + + with cute.arch.elect_one(): + pipeline_S_P.consumer_release(consumer_state_S_P_dP) + # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask) + pipeline_LSE.consumer_release(consumer_state_LSE) + # consumer_state_S_P_dP.advance() + consumer_state_LSE.advance() + + # --------------------------------------------- + # dS.T = P.T * (dP.T - D) + # --------------------------------------------- + pipeline_dPsum.consumer_wait(consumer_state_dPsum) + + pipeline_dP.consumer_wait(consumer_state_S_P_dP) + # pipeline_dP.sync_object_full.wait(0, consumer_phase_S_P_dP) + consumer_state_S_P_dP.advance() + # consumer_phase_S_P_dP ^= 1 + + ##### dS.T = P.T * (dP.T - Psum) + for stage in cutlass.range_constexpr(num_stages): + tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32) + cute.copy(thr_copy_t2r, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r) + cute.arch.fence_view_async_tmem_load() + self.compute_sync_barrier.arrive_and_wait() + tdPrdP_cur = tdPrdP_t2r[None, 0, 0] + tSrS_cur = tSrS_t2r[None, stage, 0, 0] + tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, consumer_state_dPsum.index] + if const_expr(not self.shuffle_dPsum): + tSrdPsum = cute.make_fragment_like(tSsdPsum_cur, Float32) + cute.autovec_copy(tSsdPsum_cur, tSrdPsum) + else: + tSrdPsum = tSsdPsum_cur[lane_idx] + for v in cutlass.range_constexpr(cute.size(tdPrdP_t2r, mode=[0]) // 2): + if const_expr(not self.shuffle_dPsum): + dPsum_pair = (tSrdPsum[2 * v], tSrdPsum[2 * v + 1]) + else: + dPsum_pair = ( + utils.shuffle_sync(tSrdPsum, offset=2 * v), + utils.shuffle_sync(tSrdPsum, offset=2 * v + 1), + ) + tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = utils.sub_packed_f32x2( + (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), dPsum_pair + ) + tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = utils.mul_packed_f32x2( + (tSrS_cur[2 * v], tSrS_cur[2 * v + 1]), + (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), + ) + + if const_expr(self.score_mod_bwd is not None): + tSrS_pre_cur = tSrS_pre[None, stage, 0, 0] + cS_bwd = cute.make_identity_tensor((self.tile_n, self.tile_m)) + cS_bwd = cute.domain_offset( + (n_block * self.tile_n, m_block * self.tile_m), cS_bwd + ) + tScS_bwd = thr_mma_S.partition_C(cS_bwd) + tScS_idx_bwd = thr_copy_t2r.partition_D(tScS_bwd) + tScS_idx_cur = tScS_idx_bwd[None, stage, 0, 0] + self.apply_score_mod_bwd( + tdPrdP_cur, + tSrS_pre_cur, + tScS_idx_cur, + batch_idx, + head_idx, + softmax_scale, + seqlen, + aux_tensors, + fastdiv_mods, + ) + # Zero out OOB positions (kv_idx >= seqlen_k) after score_mod_bwd + for i in cutlass.range(cute.size(tdPrdP_cur), unroll_full=True): + kv_idx = tScS_idx_cur[i][0] + tdPrdP_cur[i] = 0.0 if kv_idx >= seqlen.seqlen_k else tdPrdP_cur[i] + + tdPrdS_cvt = cute.make_fragment_like(tdPrdP_cur, self.ds_dtype) + utils.cvt_f16(tdPrdP_cur, tdPrdS_cvt) + if const_expr(stage == 0): + pipeline_dS.producer_acquire(producer_state_dS) + cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage]) + if const_expr(not self.use_smem_dS_for_mma_dK): + tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32) + cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0]) + + if const_expr(not self.use_smem_dS_for_mma_dK): + cute.arch.fence_view_async_tmem_store() + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, 
space=cute.arch.SharedSpace.shared_cta + ) + self.compute_sync_barrier.arrive_and_wait() + + # with cute.arch.elect_one(): + # The mma warp no longer waits for dP (it waits for dS), so we don't have to arrive + # pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask) + pipeline_dPsum.consumer_release(consumer_state_dPsum) + consumer_state_dPsum.advance() + with cute.arch.elect_one(): + pipeline_dS.producer_commit(producer_state_dS) + producer_state_dS.advance() + + # Epilogue + # Run epilogue if we processed any m_blocks for this n_block + if process_tile: + if const_expr(not self.use_tma_store): + consumer_state_dKV = self.epilogue_dKV( + dp_idx, + warp_idx, + batch_idx, + head_idx, + n_block, + seqlen, + thr_mma_dV, + thr_mma_dK, + tdVtdV, + tdKtdK, + mdV, + mdK, + pipeline_dKV, + consumer_state_dKV, + softmax_scale, + ) + else: + thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(dp_idx) + #### STORE dV + consumer_state_dKV = self.epilogue_dK_or_dV_tma( + dp_idx, + batch_idx, + head_idx, + n_block, + seqlen, + thr_mma_dV, + tdVtdV, + mdV_tma_tensor, + sdV, + tma_atom_dV, + thr_copy_r2s_dKV, + pipeline_dKV, + consumer_state_dKV, + None, # Don't scale + int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id + mdV_semaphore, + ) + #### STORE dK + consumer_state_dKV = self.epilogue_dK_or_dV_tma( + dp_idx, + batch_idx, + head_idx, + n_block, + seqlen, + thr_mma_dK, + tdKtdK, + mdK_tma_tensor, + sdK, + tma_atom_dK, + thr_copy_r2s_dKV, + pipeline_dKV, + consumer_state_dKV, + softmax_scale if const_expr(not self.dKV_postprocess) else None, + int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id + mdK_semaphore, + ) + # Zero dK/dV for empty tiles (local attention or block sparsity) + # When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile + if const_expr(not self.dKV_postprocess): + should_zero_dKV = False + if const_expr(self.is_local or self.is_varlen_q): + should_zero_dKV = m_block_min >= m_block_max + if const_expr(self.use_block_sparsity): + # For block sparsity, zero when no m_blocks contribute to this n_block + if not process_tile: + should_zero_dKV = True + + if should_zero_dKV: + # like other epis, currently assumes hdim == hdimv + gmem_tiled_copy_zero_dKV = copy_utils.tiled_copy_2d( + self.dk_dtype, + self.tile_hdim, + 128, # num_threads + ) + gmem_thr_copy_zero_dKV = gmem_tiled_copy_zero_dKV.get_slice(dp_idx) + mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx] + mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx] + gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) + gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) + tdKgdK = gmem_thr_copy_zero_dKV.partition_D(gdK) + tdVgdV = gmem_thr_copy_zero_dKV.partition_D(gdV) + assert tdKgdK.shape[2] == 1 + assert tdVgdV.shape[2] == 1 + cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) + tdKVcdKV = gmem_thr_copy_zero_dKV.partition_D(cdKV) + zero = cute.make_fragment_like(tdKgdK[None, 0, 0]) + zero.fill(0.0) + if tidx < 128: + for i in cutlass.range_constexpr(tdKgdK.shape[1]): + row_idx = tdKVcdKV[0, i, 0][0] + if row_idx < seqlen.seqlen_k - self.tile_n * n_block: + cute.copy(gmem_tiled_copy_zero_dKV, zero, tdKgdK[None, i, 0]) + else: + for i in cutlass.range_constexpr(tdVgdV.shape[1]): + row_idx = tdKVcdKV[0, i, 0][0] + if row_idx < seqlen.seqlen_k - self.tile_n * n_block: + cute.copy(gmem_tiled_copy_zero_dKV, zero, tdVgdV[None, i, 0]) + + tile_scheduler.advance_to_next_work() + work_tile = 
tile_scheduler.get_current_work() + + @cute.jit + def dQacc_reduce( + self, + mdQaccum: cute.Tensor, + sdQaccum: cute.Tensor, + thr_mma_dQ: cute.core.ThrMma, + tdQtdQ: cute.Tensor, + pipeline_dQ: PipelineAsync, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + mdQ_semaphore: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors] = None, + ): + num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids) + tidx = cute.arch.thread_idx()[0] % num_reduce_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx() % len(self.reduce_warp_ids)) + is_tma_warp = warp_idx == 0 + # TMEM -> RMEM + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32 + ) + thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ).get_slice(tidx) + tdQtdQ_t2r = thr_copy_t2r.partition_S(tdQtdQ) + tdQcdQ = thr_mma_dQ.partition_C(cute.make_identity_tensor(self.mma_tiler_dsk[:2])) + tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape + assert cute.size(tdQrdQ_t2r_shape, mode=[1]) == self.dQaccum_reduce_stage, ( + "dQaccum reduce stage mismatch" + ) + + thr_copy_dQaccum_r2s = copy_utils.tiled_copy_1d( + self.dqaccum_dtype, num_reduce_threads, num_copy_elems=128 // self.dqaccum_dtype.width + ).get_slice(tidx) + tdQsdQ = thr_copy_dQaccum_r2s.partition_D(sdQaccum) + + read_flag = const_expr(not self.deterministic) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + dQ_consumer_state = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, 1 + ) + dQ_tma_store_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.sdQaccum_stage + ) + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + m_block_min, m_block_max = block_info.get_m_block_min_max( + seqlen, n_block // self.cluster_shape_mnk[0] + ) + if const_expr(not seqlen.has_cu_seqlens_q): + mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + else: + mdQaccum_cur = cute.domain_offset( + (seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx] + ) + gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) + # (M * K / STAGE, STAGE, _) + gdQaccum = cute.flat_divide( + gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,) + ) + + if const_expr(self.deterministic): + mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] + + delay_semaphore_release = self.is_causal + n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n) + + # some tiles might be empty due to block sparsity + if const_expr(self.use_block_sparsity): + ( + curr_q_cnt, + curr_q_idx, + curr_full_cnt, + curr_full_idx, + loop_count, + ) = get_block_sparse_iteration_info_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + process_tile = loop_count > Int32(0) + else: + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) + or m_block_min < m_block_max + ) + loop_count = m_block_max - m_block_min + + # dQacc_reduce mainloop + # Block sparsity: iterate over sparse m_block count and derive actual m_block + # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly. 
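+ # Per iteration (sketch, comments only): wait for the dQ tile the mma warp produced in tmem,
+ # copy it TMEM -> RMEM -> SMEM in self.dQaccum_reduce_stage chunks, and let the leader warp
+ # issue a bulk reduce-add of each chunk into the fp32 dQaccum buffer in gmem, i.e. roughly
+ #   dQaccum[m_block] += dS @ K     (accumulated over every n_block that touches this m_block)
+ # In deterministic mode the per-m_block semaphore serializes these adds so the n_blocks
+ # contribute in a fixed order.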
+ for iter_idx in cutlass.range(loop_count, unroll=1): + if const_expr(self.use_block_sparsity): + m_block, _ = get_m_block_from_iter_bwd( + iter_idx, + curr_q_cnt, + curr_q_idx, + curr_full_cnt, + curr_full_idx, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + if m_block_max > 0: + m_block = cutlass.min(m_block, m_block_max - 1) + else: + m_block = m_block_min + iter_idx + pipeline_dQ.consumer_wait(dQ_consumer_state) + # TMEM -> RMEM + tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32) + cute.copy(thr_copy_t2r, tdQtdQ_t2r, tdQrdQ_t2r) + cute.arch.fence_view_async_tmem_load() + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dQ.consumer_release(dQ_consumer_state) + dQ_consumer_state.advance() + + gdQaccum_cur = gdQaccum[None, None, m_block] + + for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])): # 4 + smem_idx = dQ_tma_store_producer_state.index + tdQsdQ_r2s = tdQsdQ[None, None, smem_idx] + tdQrdQ_r2s = cute.make_tensor( + tdQrdQ_t2r[None, stage, None, None].iterator, tdQsdQ_r2s.shape + ) + cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + # semaphore acquire + if const_expr(self.deterministic and stage == 0): + if const_expr(self.spt): + if const_expr( + self.is_causal or block_info.window_size_right is not None + ): + n_idx_right = ( + (m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q + ) + if const_expr(block_info.window_size_right is not None): + n_idx_right += block_info.window_size_right + n_block_max_for_m_block = min( + n_block_global_max, + cute.ceil_div(n_idx_right, self.tile_n), + ) + else: + n_block_max_for_m_block = n_block_global_max + lock_value = n_block_max_for_m_block - 1 - n_block + else: + lock_value = n_block + barrier.wait_eq( + mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value + ) + self.reduce_sync_barrier.arrive_and_wait() + # Copy from shared memory to global memory + if is_tma_warp: + with cute.arch.elect_one(): + copy_utils.cpasync_reduce_bulk_add_f32( + sdQaccum[None, smem_idx].iterator, + gdQaccum_cur[None, stage].iterator, + self.tma_copy_bytes["dQ"] // 1, + ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag) + self.reduce_sync_barrier.arrive_and_wait() + dQ_tma_store_producer_state.advance() + # Directly add to gmem, much slower + # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block]) + # assert cute.size(tdQrdQ_r2s) == cute.size(tdQgdQ) + # for i in cutlass.range(cute.size(tdQrdQ_r2s) // 4, unroll_full=True): + # copy_utils.atomic_add_fp32x4( + # tdQrdQ_r2s[4 * i], + # tdQrdQ_r2s[4 * i + 1], + # tdQrdQ_r2s[4 * i + 2], + # tdQrdQ_r2s[4 * i + 3], + # utils.elem_pointer(tdQgdQ, 4 * i), + # ) + # semaphore release for prior m_block + if const_expr(self.deterministic and stage == 0 and delay_semaphore_release): + if m_block > m_block_min: + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1 + ) + + # semaphore release + # NOTE: arrive_inc calls red_release which issues membar + if const_expr(self.deterministic and not delay_semaphore_release): + if is_tma_warp: + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + self.reduce_sync_barrier.arrive_and_wait() + barrier.arrive_inc(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, 1) + + if const_expr(not 
self.is_local) or m_block_min < m_block_max: + if is_tma_warp: + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + self.reduce_sync_barrier.arrive_and_wait() + # final semaphore release + if const_expr(self.deterministic and delay_semaphore_release): + barrier.arrive_inc( + mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, 0, 1 + ) + + if const_expr( + self.deterministic and not self.spt and block_info.window_size_left is not None + ): + m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m) + for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1): + barrier.arrive_inc(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, 1) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def epilogue_dKV( + self, + tidx: Int32, + warp_idx: Int32, + batch_idx: Int32, + head_idx: Int32, + n_block: Int32, + seqlen, + thr_mma_dV: cute.core.ThrMma, + thr_mma_dK: cute.core.ThrMma, + tdVtdV: cute.Tensor, + tdKtdK: cute.Tensor, + mdV: cute.Tensor, + mdK: cute.Tensor, + pipeline_dKV: PipelineAsync, + consumer_state_dKV: cutlass.pipeline.PipelineState, + softmax_scale: Float32, + ): + wg_idx = ( + cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids)) + ) // 128 + num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128 + + assert self.qhead_per_kvhead == 1, "This epilogue path is only for MHA" + mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx] + mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx] + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32 + ) + + # dV + pipeline_dKV.consumer_wait(consumer_state_dKV) + + tiled_tmem_ld_dV = tcgen05.make_tmem_copy(tmem_load_atom, tdVtdV) + thr_tmem_ld_dV = tiled_tmem_ld_dV.get_slice(tidx) + + tdVtdV_t2r_p = thr_tmem_ld_dV.partition_S(tdVtdV) + tdVtdV_t2r = self.split_wg(tdVtdV_t2r_p, wg_idx, num_wg) + + cdV = cute.make_identity_tensor((self.mma_tiler_pdo[0], self.mma_tiler_pdo[1])) + tdVcdV = thr_mma_dV.partition_C(cdV) + tdVcdV_tensor = cute.make_tensor(tdVcdV.iterator, tdVcdV.layout) + + tdVcdV_t2r_p = thr_tmem_ld_dV.partition_D(tdVcdV_tensor) + tdVcdV_t2r = self.split_wg(tdVcdV_t2r_p, wg_idx, num_wg) + tdVrdV_t2r = cute.make_fragment(tdVcdV_t2r.shape, Float32) + + cute.copy(thr_tmem_ld_dV, tdVtdV_t2r, tdVrdV_t2r) + cute.arch.fence_view_async_tmem_load() + + universal_copy_bits = 128 + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dv_dtype, + num_bits_per_copy=universal_copy_bits, + ) + tiled_gmem_store_dV = cute.make_tiled_copy( + atom_universal_copy, + layout_tv=tiled_tmem_ld_dV.layout_dst_tv_tiled, + tiler_mn=tiled_tmem_ld_dV.tiler_mn, + ) + + tdVrdV_r2s = cute.make_fragment(tdVrdV_t2r.shape, self.dv_dtype) + for i in cutlass.range_constexpr(cute.size(tdVrdV_t2r, mode=[1])): + dV_vec = tdVrdV_t2r[(None, i, 0, 0)].load() + tdVrdV_r2s[(None, i, 0, 0)].store(dV_vec.to(self.dv_dtype)) + + gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) + gdV_tile = gdV[None, None, n_block] + + tdVgdV = thr_mma_dV.partition_C(gdV_tile) + tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV) + tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg) + + if tidx < seqlen.seqlen_k - self.tile_n * n_block: + cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g) + + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dKV.consumer_release(consumer_state_dKV) + consumer_state_dKV.advance() + + # 
dK + pipeline_dKV.consumer_wait(consumer_state_dKV) + + tiled_tmem_ld_dK = tcgen05.make_tmem_copy(tmem_load_atom, tdKtdK) + thr_tmem_ld_dK = tiled_tmem_ld_dK.get_slice(tidx) + + tdKtdK_t2r_p = thr_tmem_ld_dK.partition_S(tdKtdK) + tdKtdK_t2r = self.split_wg(tdKtdK_t2r_p, wg_idx, num_wg) + + cdK = cute.make_identity_tensor((self.mma_tiler_dsq[0], self.mma_tiler_dsq[1])) + tdKcdK = thr_mma_dK.partition_C(cdK) + tdKcdK_tensor = cute.make_tensor(tdKcdK.iterator, tdKcdK.layout) + + tdKcdK_t2r_p = thr_tmem_ld_dK.partition_D(tdKcdK_tensor) + tdKcdK_t2r = self.split_wg(tdKcdK_t2r_p, wg_idx, num_wg) + tdKrdK_t2r = cute.make_fragment(tdKcdK_t2r.shape, Float32) + + cute.copy(tiled_tmem_ld_dK, tdKtdK_t2r, tdKrdK_t2r) + cute.arch.fence_view_async_tmem_load() + + universal_copy_bits = 128 + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dk_dtype, + num_bits_per_copy=universal_copy_bits, + ) + + tiled_gmem_store_dK = cute.make_tiled_copy( + atom_universal_copy, + layout_tv=tiled_tmem_ld_dK.layout_dst_tv_tiled, + tiler_mn=tiled_tmem_ld_dK.tiler_mn, + ) + + tdKrdK_r2s = cute.make_fragment(tdKrdK_t2r.shape, self.dk_dtype) + + for i in cutlass.range_constexpr(cute.size(tdKrdK_t2r, mode=[1])): + dK_vec = tdKrdK_t2r[(None, i, 0, 0)].load() * softmax_scale + tdKrdK_r2s[(None, i, 0, 0)].store(dK_vec.to(self.dk_dtype)) + + gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdimv), (None, 0)) + gdK_tile = gdK[None, None, n_block] + + tdKgdK = thr_mma_dK.partition_C(gdK_tile) + tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK) + tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg) + + if tidx < seqlen.seqlen_k - self.tile_n * n_block: + cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g) + + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dKV.consumer_release(consumer_state_dKV) + consumer_state_dKV.advance() + return consumer_state_dKV + + @cute.jit + def epilogue_dK_or_dV_tma( + self, + tidx: Int32, + batch_idx: Int32, + head_idx: Int32, + n_block: Int32, + seqlen, + thr_mma: cute.core.ThrMma, + tdKVtdKV: cute.Tensor, + mdKV: cute.Tensor, + sdKV: cute.Tensor, + tma_atom_dKV: cute.CopyAtom, + thr_copy_r2s_dKV: cute.TiledCopy, + pipeline_dKV: PipelineAsync, + consumer_state_dKV: cutlass.pipeline.PipelineState, + scale: Optional[Float32], + barrier_id: Int32, + mdKV_semaphore: Optional[cute.Tensor], + ) -> cutlass.pipeline.PipelineState: + # assumes mma_tiler_pdo = mma_tiler_dsq = (tile_n, head_dim) + # head_dim = head_dim_v, dk_dtype = dv_dtype + num_compute_threads = cute.arch.WARP_SIZE * len(self.compute_warp_ids) + wg_idx = (cute.arch.thread_idx()[0] % num_compute_threads) // 128 + num_wg = num_compute_threads // 128 + leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 + + if const_expr(not self.dKV_postprocess): + sdKV = sdKV[None, None, wg_idx] # (tile_n, 64) for bf16 + else: + sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32 + + # (8, tile_n / 128, 64 / 8) = (8, 1, 8) or (4, tile_n * 32 / (128 * 4)) = (4, 8) + tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV) + + head_idx_kv = head_idx // self.qhead_per_kvhead + if const_expr(not self.dKV_postprocess): + assert not seqlen.has_cu_seqlens_k, "varlen uses non tma store path" + mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim) + gdKV_p = cute.local_tile( + mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0) + ) # (tile_n, hdim) + gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) # (tile_n, hdim / 2) + gdKV_epi = cute.local_tile( + gdKV, self.sdKV_epi_tile, (0, 
None) + ) # (tile_n, 64, epi_stage = (hdim / 2) / 64) + else: + if const_expr(not seqlen.has_cu_seqlens_k): + mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) + else: + mdKV_cur = cute.domain_offset( + (seqlen.padded_offset_k * self.tile_hdim,), mdKV[None, head_idx_kv] + ) + gdKV_p = cute.local_tile( + mdKV_cur, (self.tile_n * self.tile_hdim,), (n_block,) + ) # (tile_n * hdim) + gdKV = cute.logical_divide(gdKV_p, (self.tile_n * self.tile_hdim // num_wg,))[ + ((None, wg_idx),) + ] # (tile_n * hdim / 2) + gdKV_epi = cute.flat_divide( + gdKV, (self.sdKV_flat_epi_tile,) + ) # (tile_n * hdim / 2 / epi_stage, epi_stage) + + deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1 + if const_expr(deterministic_KV): + mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] + + if const_expr(not self.dKV_postprocess): + tdKVsdKV, tdKVgdKV = cpasync.tma_partition( + tma_atom_dKV, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sdKV, 0, 2), + cute.group_modes(gdKV_epi, 0, 2), + ) # (TMA) and (TMA, EPI_STAGE) + assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV" + assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV" + num_epi_stages = cute.size(tdKVgdKV.shape[1]) + assert num_epi_stages == self.num_epi_stages, "Epi stage calculation is wrong" + else: + num_epi_stages = self.num_epi_stages + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 + ) + + read_flag = const_expr(not deterministic_KV) + + pipeline_dKV.consumer_wait(consumer_state_dKV) + + # semaphore acquire + if const_expr(deterministic_KV): + barrier.wait_eq( + mdKV_semaphore_cur.iterator, tidx, wg_idx, head_idx % self.qhead_per_kvhead + ) + cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) + + for epi_stage in cutlass.range_constexpr(num_epi_stages): + # TMEM -> RMEM -- setup + thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV).get_slice(tidx) + tdKVtdKV_t2r_p = thr_copy_t2r.partition_S(tdKVtdKV) + tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] + if const_expr(num_epi_stages > 1): + tdKVtdKV_t2r = tdKVtdKV_t2r[None, epi_stage] + + cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) + tdKVcdKV = thr_mma.partition_C(cdKV) + tdKVcdKV_t2r_p = thr_copy_t2r.partition_D(tdKVcdKV) + tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] + if const_expr(num_epi_stages > 1): + tdKVcdKV_t2r = tdKVcdKV_t2r[None, epi_stage] + + tdKVrdKV_t2r = cute.make_fragment(tdKVcdKV_t2r.shape, Float32) + + assert cute.size(tdKVrdKV_t2r) == cute.size(tdKVtdKV_t2r) // cute.arch.WARP_SIZE, ( + "RMEM<->TMEM fragment size mismatch" + ) + + # TMEM -> RMEM -- copy and fence + cute.copy(thr_copy_t2r, tdKVtdKV_t2r, tdKVrdKV_t2r) + cute.arch.fence_view_async_tmem_load() + + # RMEM -- scale and convert + if const_expr(scale is not None): + for i in cutlass.range(cute.size(tdKVrdKV_t2r.shape) // 2, unroll_full=True): + tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1] = utils.mul_packed_f32x2( + (tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1]), (scale, scale) + ) + tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) # (32 columns) + tdKVrdKV.store(tdKVrdKV_t2r.load().to(self.dv_dtype)) + + # RMEM -> SMEM -- copy, fence and barrier + tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVsdKV_r2s.shape) + cute.copy(thr_copy_r2s_dKV, tdKVrdKV_r2s, tdKVsdKV_r2s) + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, 
space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) + + # SMEM -> GMEM + if leader_warp: + if const_expr(not self.dKV_postprocess): + cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, epi_stage]) + else: + with cute.arch.elect_one(): + copy_utils.cpasync_reduce_bulk_add_f32( + sdKV.iterator, + gdKV_epi[None, epi_stage].iterator, + self.tma_copy_bytes["dKacc"], + ) + if const_expr(epi_stage < num_epi_stages - 1): + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + cute.arch.barrier_arrive( + barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE + ) + + # Barrier since all warps need to wait for SMEM to be freed + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.barrier( + barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE + ) + + # semaphore release + # NOTE: arrive_inc calls red_release which issues membar + if const_expr(deterministic_KV): + if leader_warp: + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) + cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) + barrier.arrive_inc(mdKV_semaphore_cur.iterator, tidx, wg_idx, 1) + + cute.arch.sync_warp() + with cute.arch.elect_one(): + pipeline_dKV.consumer_release(consumer_state_dKV) + consumer_state_dKV.advance() + return consumer_state_dKV diff --git a/flash_attn/cute/flash_bwd_sm90.py b/flash_attn/cute/flash_bwd_sm90.py new file mode 100644 index 00000000000..377a66a4385 --- /dev/null +++ b/flash_attn/cute/flash_bwd_sm90.py @@ -0,0 +1,1708 @@ +import math +from typing import Callable, Optional, Type +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.utils.hopper_helpers as sm90_utils_basic +from cutlass.cute.nvgpu import cpasync, warpgroup +from cutlass.cute.arch import ProxyKind, SharedSpace +from cutlass.cute import FastDivmodDivisor +from cutlass import Float32, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum + +from flash_attn.cute import hopper_helpers as sm90_utils +from flash_attn.cute import utils +from flash_attn.cute import copy_utils +from flash_attn.cute.hopper_helpers import gemm_zero_init, gemm_w_idx +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute import pipeline +from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, ParamsBase +from flash_attn.cute.named_barrier import NamedBarrierFwd, NamedBarrierBwd +from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + get_total_q_block_count_bwd, + produce_block_sparse_q_loads_bwd_sm90, + consume_block_sparse_mma_bwd_sm90, + dQaccum_store_block_sparse_bwd_sm90, +) + + +def mma_partition_fragment_AB( + thr_mma: cute.core.ThrMma, sA: Optional[cute.Tensor], sB: Optional[cute.Tensor], swap_AB: bool +): + if const_expr(not swap_AB): + return ( + thr_mma.make_fragment_A(thr_mma.partition_A(sA)) if sA is not None else None, + thr_mma.make_fragment_B(thr_mma.partition_B(sB)) if sB is not None else None, + ) + else: + return ( + thr_mma.make_fragment_B(thr_mma.partition_B(sA)) if sA is not None else None, + 
thr_mma.make_fragment_A(thr_mma.partition_A(sB)) if sB is not None else None, + ) + + +class FlashAttentionBackwardSm90: + arch = 90 + + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: int = 1, + is_causal: bool = False, + tile_m: int = 64, + tile_n: int = 128, + Q_stage: int = 2, + dO_stage: int = 2, + PdS_stage: int = 2, + SdP_swapAB: bool = False, + dKV_swapAB: bool = False, + dQ_swapAB: bool = False, + AtomLayoutMSdP: int = 1, + AtomLayoutNdKV: int = 2, + AtomLayoutMdQ: int = 1, + num_threads: int = 384, + V_in_regs: bool = False, + score_mod: cutlass.Constexpr | None = None, + score_mod_bwd: cutlass.Constexpr | None = None, + mask_mod: cutlass.Constexpr | None = None, + has_aux_tensors: cutlass.Constexpr = False, + subtile_factor: cutlass.Constexpr[int] = 1, + ): + self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + # Can save registers (and hence be faster) if we don't have to check hdim predication + self.check_hdim_oob = head_dim != self.tile_hdim + self.check_hdim_v_oob = head_dim_v != self.tile_hdimv + self.qhead_per_kvhead = qhead_per_kvhead + self.is_causal = is_causal + self.is_local = False + self.tile_m = tile_m + self.tile_n = tile_n + self.num_threads = num_threads + self.Q_stage = Q_stage + self.dO_stage = dO_stage + self.PdS_stage = PdS_stage + assert self.dO_stage in [1, self.Q_stage] + assert self.PdS_stage in [1, self.Q_stage] + self.SdP_swapAB = SdP_swapAB + self.dKV_swapAB = dKV_swapAB + self.dQ_swapAB = dQ_swapAB + self.AtomLayoutMSdP = AtomLayoutMSdP + self.AtomLayoutNdKV = AtomLayoutNdKV + self.AtomLayoutMdQ = AtomLayoutMdQ + self.num_mma_warp_groups = (self.num_threads // 128) - 1 + self.mma_dkv_is_rs = ( + AtomLayoutMSdP == 1 + and AtomLayoutNdKV == self.num_mma_warp_groups + and SdP_swapAB + and not dKV_swapAB + ) + self.V_in_regs = V_in_regs + if qhead_per_kvhead > 1: + assert self.same_hdim_kv, "GQA backward requires head_dim == head_dim_v" + assert self.num_mma_warp_groups == 2, "GQA backward assumes 2 warp groups" + # These are tuned for speed + # Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share + # them and then shuffle to get the value whenever we need? This can reduce register + # pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4) + # rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows. 
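# ----------------------------------------------------------------------------------
# [Editor's note: illustrative sketch, not part of this patch.]  The *_swapAB options
# above and the mma_partition_fragment_AB helper rely on the identity
# (A @ B^T)^T == B @ A^T: feeding the operands to the MMA in swapped order produces
# the transposed output tile, which changes which rows of the accumulator each thread
# owns (hence the register-pressure note above).  A minimal NumPy check of that
# identity, with made-up tile sizes:
import numpy as np

_M, _N, _K = 64, 128, 16
_A = np.random.randn(_M, _K).astype(np.float32)   # e.g. a Q tile
_B = np.random.randn(_N, _K).astype(np.float32)   # e.g. a K tile
_C = _A @ _B.T            # "normal" orientation: (tile_m, tile_n)
_C_swapped = _B @ _A.T    # swap_AB orientation: (tile_n, tile_m)
assert np.allclose(_C_swapped, _C.T, atol=1e-5)
# ----------------------------------------------------------------------------------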
+ # TODO: impl these for hdim 64 + self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64 + self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64 + + self.score_mod = score_mod + self.score_mod_bwd = score_mod_bwd + self.mask_mod = mask_mod + self.has_aux_tensors = has_aux_tensors + self.subtile_factor = subtile_factor + if cutlass.const_expr(has_aux_tensors): + self.vec_size: cutlass.Constexpr = 1 + else: + self.vec_size: cutlass.Constexpr = 4 + self.qk_acc_dtype = Float32 + + @staticmethod + def can_implement( + dtype, + head_dim, + head_dim_v, + tile_m, + tile_n, + Q_stage, + num_threads, + V_in_regs=False, + ) -> bool: + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if head_dim_v % 8 != 0: + return False + if tile_n % 16 != 0: + return False + if num_threads % 32 != 0: + return False + if (tile_m * 2) % num_threads != 0: + return False + return True + + def _check_type( + self, + mQ_type: Type[cutlass.Numeric], + mK_type: Type[cutlass.Numeric], + mV_type: Type[cutlass.Numeric], + mdO_type: Type[cutlass.Numeric], + mLSE_type: Type[cutlass.Numeric], + mdPsum_type: Type[cutlass.Numeric], + mdQaccum_type: Type[cutlass.Numeric], + mdK_type: Type[cutlass.Numeric], + mdV_type: Type[cutlass.Numeric], + ): + # Get the data type and check if it is fp16 or bf16 + if const_expr(not (mQ_type == mK_type == mV_type == mdO_type)): + raise TypeError("All tensors must have the same data type") + if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): + raise TypeError("Only Float16 or BFloat16 is supported") + if const_expr(mLSE_type not in [Float32]): + raise TypeError("LSE tensor must be Float32") + if const_expr(mdPsum_type not in [Float32]): + raise TypeError("dPsum tensor must be Float32") + if const_expr(mdQaccum_type not in [Float32]): + raise TypeError("dQaccum tensor must be Float32") + if const_expr(self.qhead_per_kvhead == 1): + if const_expr(not (mdK_type == mdV_type == mQ_type)): + raise TypeError("mdK and mdV tensors must have the same data type as mQ") + else: + if const_expr(not (mdK_type == mdV_type == Float32)): + raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32") + assert mQ_type == self.dtype + + def _setup_attributes(self): + self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout, self.sPdS_layout = [ + sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage) + for shape, stage in [ + ((self.tile_m, self.tile_hdim), self.Q_stage), + ((self.tile_n, self.tile_hdim), None), + ((self.tile_n, self.tile_hdimv), None), + ((self.tile_m, self.tile_hdimv), self.dO_stage), + ((self.tile_m, self.tile_n), self.PdS_stage), + ] + ] + self.sdQaccum_layout = cute.make_layout( + (self.tile_m * self.tile_hdim // self.num_mma_warp_groups, self.num_mma_warp_groups) + ) + # dQaccum R->S + self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), + # thr_layout + cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)), + cute.make_layout(128 // Float32.width), # val_layout + ) + # dKVaccum for GQA epilogue - reuses sV+sK memory recast as f32 + self.sdKVaccum_layout = cute.make_layout( + (self.tile_n * self.tile_hdim // self.num_mma_warp_groups, self.num_mma_warp_groups) + ) + # dKVaccum R->S (same pattern as dQaccum but sized for tile_n) + self.r2s_tiled_copy_dKVaccum = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, 
num_bits_per_copy=128), + cute.make_layout((self.num_threads_per_warp_group, self.num_mma_warp_groups)), + cute.make_layout(128 // Float32.width), + ) + + def _get_tiled_mma(self): + # S = Q @ K.T, dP = dO @ V.T + atom_layout_SdP = (self.AtomLayoutMSdP, self.num_mma_warp_groups // self.AtomLayoutMSdP) + tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1]) + tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.K, + Float32, + atom_layout_mnk=(atom_layout_SdP if not self.SdP_swapAB else atom_layout_SdP[::-1]) + + (1,), + tiler_mn=tiler_mn_SdP if not self.SdP_swapAB else tiler_mn_SdP[::-1], + ) + # dV = P.T @ dO, dK = dS.T @ Q + atom_layout_dKV = (self.AtomLayoutNdKV, self.num_mma_warp_groups // self.AtomLayoutNdKV) + tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1]) + tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1]) + tiled_mma_dK, tiled_mma_dV = [ + sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.MN + if not self.mma_dkv_is_rs + else warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + Float32, + atom_layout_mnk=(atom_layout_dKV if not self.dKV_swapAB else atom_layout_dKV[::-1]) + + (1,), + tiler_mn=tiler_mn_d if not self.dKV_swapAB else tiler_mn_d[::-1], + a_source=warpgroup.OperandSource.RMEM + if self.mma_dkv_is_rs + else warpgroup.OperandSource.SMEM, + ) + for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV) + ] + # dQ = dS @ K + atom_layout_dQ = (self.AtomLayoutMdQ, self.num_mma_warp_groups // self.AtomLayoutMdQ) + tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1]) + tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN, + warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K, + Float32, + atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + (1,), + tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1], + ) + return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ + + def _get_shared_storage_cls(self): + sQ_alignment = sK_alignment = sV_alighment = sdQaccum_alignment = sdO_alignment = 1024 + + sQ_struct, sK_struct, sV_struct, sdO_struct, sdQaccum_struct = [ + cute.struct.Align[cute.struct.MemRange[type, cute.cosize(layout)], alignment] + for (layout, type, alignment) in [ + (self.sQ_layout, self.dtype, sQ_alignment), + (self.sK_layout, self.dtype, sK_alignment), + (self.sV_layout, self.dtype, sV_alighment), + (self.sdO_layout, self.dtype, sdO_alignment), + (self.sdQaccum_layout, Float32, sdQaccum_alignment), + ] + ] + + cosize_sdS = cute.cosize(self.sPdS_layout) + cosize_sP = cute.cosize(self.sPdS_layout) if const_expr(not self.mma_dkv_is_rs) else 0 + sLSE_struct = cute.struct.Align[ + cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.Q_stage], 128 + ] + sdPsum_struct = cute.struct.Align[ + cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.dO_stage], 128 + ] + + @cute.struct + class SharedStorageQKV: + mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.Q_stage * 2] + mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.dO_stage * 2] + sLSE: sLSE_struct + sdPsum: sdPsum_struct + sQ: sQ_struct + sV: sV_struct + sK: sK_struct + sdO: sdO_struct + sP: 
cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024] + sdS: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sdS], 1024] + sdQaccum: sdQaccum_struct + + return SharedStorageQKV + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdQaccum: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + softmax_scale: Float32, + stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + mdQ_semaphore: Optional[cute.Tensor] = None, + mdK_semaphore: Optional[cute.Tensor] = None, + mdV_semaphore: Optional[cute.Tensor] = None, + aux_tensors: Optional[list] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + ): + assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, ( + "determinism not supported yet for Sm90" + ) + + self._check_type( + *( + t.element_type if t is not None else None + for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV) + ) + ) + + # Assume all strides are divisible by 128 bits except the last stride + # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints) + new_stride = lambda t: ( + *( + cute.assume(s, divby=128 // t.element_type.width) + if not isinstance(s, int) or s != 0 + else s + for s in t.stride[:-1] + ), + t.stride[-1], + ) + mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV) + ] + + layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) + mQ, mK, mV, mdO = [utils.select(t, layout_transpose) for t in (mQ, mK, mV, mdO)] + if const_expr(self.qhead_per_kvhead == 1): + mdK, mdV = [utils.select(t, layout_transpose) for t in (mdK, mdV)] + else: + accum_transpose = [2, 1, 0] # (b, n, s*h) -> (s*h, n, b) + mdK, mdV = [utils.select(t, accum_transpose) for t in (mdK, mdV)] + LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) -> (s, n, b) + mLSE, mdPsum, mdQaccum = [ + utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) + ] + + tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma() + + self.num_mma_threads = tiled_mma_SdP.size + assert self.num_mma_threads + 128 == self.num_threads + + self.num_threads_per_warp_group = 128 + self.num_producer_threads = 32 + + self.num_mma_regs = 240 + self.num_producer_regs = 24 + # self.num_mma_regs = 232 + # self.num_producer_regs = 40 + + self._setup_attributes() + SharedStorage = self._get_shared_storage_cls() + + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1])) + for name, mX, layout in [ + ("Q", mQ, self.sQ_layout), + ("K", mK, self.sK_layout), + ("V", mV, self.sV_layout), + ("dO", mdO, self.sdO_layout), + ] + } + self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 + self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 + self.tma_copy_bytes["dQ"] = ( + self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_mma_warp_groups + ) + self.tma_copy_bytes["dKacc"] = self.tile_n * self.tile_hdim * Float32.width // 8 + self.tma_copy_bytes["dVacc"] = self.tile_n 
* self.tile_hdimv * Float32.width // 8 + + tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mQ, + cute.select(self.sQ_layout, mode=[0, 1]), + (self.tile_m, self.tile_hdim), + ) + tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mK, + cute.select(self.sK_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdim), + ) + tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mV, + cute.select(self.sV_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdimv), + ) + tma_atom_dO, tma_tensor_dO = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + mdO, + cute.select(self.sdO_layout, mode=[0, 1]), + (self.tile_m, self.tile_hdimv), + ) + if const_expr(self.qhead_per_kvhead == 1): + tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + mdK, + cute.select(self.sK_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdim), + ) + tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + mdV, + cute.select(self.sV_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdimv), + ) + else: + tma_atom_dK = tma_atom_dV = tma_tensor_dK = tma_tensor_dV = None + + TileScheduler = SingleTileScheduler + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mK.shape[0]), self.tile_n), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]), + 1, # num_splits + cute.size(mK.shape[0]), + mQ.shape[1], + mV.shape[1], + total_q=cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=(self.tile_m, self.tile_n), + mCuSeqlensQ=None, + mSeqUsedQ=None, + qhead_per_kvhead_packgqa=1, + element_size=self.dtype.width // 8, + is_persistent=False, + lpt=False, + ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + + LOG2_E = math.log2(math.e) + if const_expr(self.score_mod is None): + softmax_scale_log2 = softmax_scale * LOG2_E + else: + softmax_scale_log2 = LOG2_E + + fastdiv_mods = None + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) + seqlen_k = cute.size(mK.shape[0]) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + + qhead_per_kvhead_divmod = None + if const_expr(self.qhead_per_kvhead > 1): + qhead_per_kvhead_divmod = FastDivmodDivisor(self.qhead_per_kvhead) + + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + + self.kernel( + tma_tensor_Q, + tma_tensor_K, + tma_tensor_V, + tma_tensor_dO, + tma_tensor_dK if const_expr(self.qhead_per_kvhead == 1) else mdK, + tma_tensor_dV if const_expr(self.qhead_per_kvhead == 1) else mdV, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_dO, + tma_atom_dK, + tma_atom_dV, + mLSE, + mdPsum, + mdQaccum, + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sPdS_layout, + self.sdO_layout, + self.sdQaccum_layout, + self.sdKVaccum_layout, + self.r2s_tiled_copy_dQaccum, + self.r2s_tiled_copy_dKVaccum, + tiled_mma_SdP, + tiled_mma_dK, + tiled_mma_dV, + tiled_mma_dQ, + softmax_scale_log2, + softmax_scale, + tile_sched_params, + TileScheduler, + SharedStorage, + aux_tensors, + fastdiv_mods, + blocksparse_tensors, + qhead_per_kvhead_divmod, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, 
+ mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + tma_atom_dK: cute.CopyAtom, + tma_atom_dV: cute.CopyAtom, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdQaccum: cute.Tensor, + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sPdS_layout: cute.ComposedLayout, + sdO_layout: cute.ComposedLayout, + sdQaccum_layout: cute.Layout, + sdKVaccum_layout: cute.Layout, + r2s_tiled_copy_dQaccum: cute.TiledCopy, + r2s_tiled_copy_dKVaccum: cute.TiledCopy, + tiled_mma_SdP: cute.TiledMma, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tiled_mma_dQ: cute.TiledMma, + softmax_scale_log2, + softmax_scale, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + SharedStorage: cutlass.Constexpr[Callable], + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + blocksparse_tensors: Optional[BlockSparseTensors] = None, + qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # prefetch TMA descriptors + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + cpasync.prefetch_descriptor(tma_atom_dO) + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread) + pipeline_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE + ) + pipeline_Q = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.mbar_ptr_Q.data_ptr(), + num_stages=self.Q_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_bytes["Q"] + self.tma_copy_bytes["LSE"], + defer_sync=True, + ) + pipeline_dO = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.mbar_ptr_dO.data_ptr(), + num_stages=self.dO_stage, + producer_group=pipeline_producer_group, + consumer_group=pipeline_consumer_group, + tx_count=self.tma_copy_bytes["dO"] + self.tma_copy_bytes["dPsum"], + defer_sync=False, + ) + + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + sP = None + if const_expr(not self.mma_dkv_is_rs): + sP = storage.sP.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) + sdS = storage.sdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner) + sLSE = storage.sLSE.get_tensor( + cute.make_layout( + (self.tile_m, self.Q_stage), + stride=(1, cute.round_up(self.tile_m, 64)), + ) + ) + sdPsum = storage.sdPsum.get_tensor( + cute.make_layout( + (self.tile_m, self.dO_stage), + stride=(1, cute.round_up(self.tile_m, 64)), + ) + ) + sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout) + + block_info = BlockInfo( + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + False, # is_split_kv + None, + None, + qhead_per_kvhead_packgqa=1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=None, + mCuSeqlensK=None, + mSeqUsedQ=None, + 
mSeqUsedK=None, + ) + AttentionMaskCls = partial( + AttentionMask, + self.tile_m, + self.tile_n, + window_size_left=None, + window_size_right=None, + swap_AB=self.SdP_swapAB, + ) + TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) + + if warp_idx < 4: + cute.arch.warpgroup_reg_dealloc(self.num_producer_regs) + if warp_idx == 0: + self.load( + mQ, + mK, + mV, + mdO, + mLSE, + mdPsum, + sQ, + sK, + sV, + sdO, + sLSE, + sdPsum, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_dO, + pipeline_Q, + pipeline_dO, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + blocksparse_tensors, + qhead_per_kvhead_divmod, + ) + if warp_idx == 1: + for warp_group_idx in cutlass.range(self.num_mma_warp_groups): + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + self.dQaccum_store( + mdQaccum, + sdQaccum, + block_info, + TileSchedulerCls, + SeqlenInfoCls, + blocksparse_tensors, + ) + else: + cute.arch.warpgroup_reg_alloc(self.num_mma_regs) + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx - 128 + self.mma( + tiled_mma_SdP, + tiled_mma_dK, + tiled_mma_dV, + tiled_mma_dQ, + mdK, + mdV, + mdQaccum, + sQ, + sK, + sV, + sdO, + sP, + sdS, + sLSE, + sdPsum, + sdQaccum, + pipeline_Q, + pipeline_dO, + tidx, + tma_atom_dK, + tma_atom_dV, + r2s_tiled_copy_dQaccum, + r2s_tiled_copy_dKVaccum, + sdKVaccum_layout, + softmax_scale_log2, + softmax_scale, + block_info, + SeqlenInfoCls, + AttentionMaskCls, + TileSchedulerCls, + aux_tensors, + fastdiv_mods, + blocksparse_tensors, + qhead_per_kvhead_divmod, + ) + + @cute.jit + def load( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sdO: cute.Tensor, + sLSE: cute.Tensor, + sdPsum: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_dO: cute.CopyAtom, + pipeline_Q: cutlass.pipeline.PipelineAsync, + pipeline_dO: cutlass.pipeline.PipelineAsync, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, + ): + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + + if warp_idx_in_wg == 0: + producer_state_Q = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.Q_stage + ) + producer_state_dO = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.dO_stage + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + head_idx_kv = ( + head_idx + if const_expr(self.qhead_per_kvhead == 1) + else head_idx // qhead_per_kvhead_divmod + ) + mK_cur = mK[None, None, head_idx_kv, batch_idx] + gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) + mV_cur = mV[None, None, head_idx_kv, batch_idx] + gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) + + mQ_cur = mQ[None, None, head_idx, batch_idx] + gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0)) + mdO_cur = mdO[None, None, head_idx, batch_idx] + gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0)) + mLSE_cur = mLSE[None, head_idx, 
batch_idx] + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,)) + mdPsum_cur = mdPsum[None, head_idx, batch_idx] + gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,)) + + load_K, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), gK, sK, single_stage=True + ) + load_V, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, 0, cute.make_layout(1), gV, sV, single_stage=True + ) + load_Q, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ, sQ + ) + load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q) + load_dO, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dO, 0, cute.make_layout(1), gdO, sdO + ) + load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO) + load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE) + load_LSE = copy_utils.tma_producer_copy_fn(load_LSE, pipeline_Q) + load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum) + load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_dO) + + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + + if const_expr(not self.use_block_sparsity): + total_m_block_cnt = m_block_max - m_block_min + process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + else: + total_m_block_cnt = get_total_q_block_count_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + process_tile = total_m_block_cnt > Int32(0) + + if process_tile: + if const_expr(not self.use_block_sparsity): + first_m_block = m_block_min + pipeline_Q.producer_acquire( + producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q)) + load_Q(first_m_block, producer_state=producer_state_Q) + with cute.arch.elect_one(): + load_LSE(first_m_block, producer_state=producer_state_Q) + producer_state_dO_cur = ( + producer_state_dO + if const_expr(self.Q_stage != self.dO_stage) + else producer_state_Q + ) + pipeline_dO.producer_acquire( + producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"] + ) + load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur)) + load_dO(first_m_block, producer_state=producer_state_dO_cur) + with cute.arch.elect_one(): + load_dPsum(first_m_block, producer_state=producer_state_dO_cur) + producer_state_Q.advance() + producer_state_dO.advance() + + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + pipeline_Q.producer_acquire(producer_state_Q) + load_Q(m_block, producer_state=producer_state_Q) + with cute.arch.elect_one(): + load_LSE(m_block, producer_state=producer_state_Q) + producer_state_dO_cur = ( + producer_state_dO + if const_expr(self.Q_stage != self.dO_stage) + else producer_state_Q + ) + pipeline_dO.producer_acquire(producer_state_dO_cur) + load_dO(m_block, producer_state=producer_state_dO_cur) + with cute.arch.elect_one(): + load_dPsum(m_block, producer_state=producer_state_dO_cur) + producer_state_Q.advance() + producer_state_dO.advance() + else: + producer_state_Q, producer_state_dO = produce_block_sparse_q_loads_bwd_sm90( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + producer_state_Q, + producer_state_dO, + pipeline_Q, + pipeline_dO, + load_K, + load_V, + load_Q, + load_dO, + load_LSE, + load_dPsum, + self.tma_copy_bytes["K"], + self.tma_copy_bytes["V"], + Q_stage_eq_dO_stage=(self.Q_stage == self.dO_stage), + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + + 
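# ----------------------------------------------------------------------------------
# [Editor's note: illustrative sketch, not part of this patch.]  producer_state_Q /
# producer_state_dO above act as cursors into a small ring of smem stages: each
# advance() moves to the next stage and flips a phase bit when the index wraps, and
# that phase is what producer_acquire / consumer_wait compare against the mbarrier
# to tell "this lap" from "the previous lap".  A toy Python model of just the
# index/phase bookkeeping (the real cutlass.pipeline state also tracks mbarriers and
# differs in its initial producer/consumer phase):
class _ToyPipelineState:
    def __init__(self, num_stages: int):
        self.num_stages = num_stages
        self.index = 0   # which smem stage to use next
        self.phase = 0   # flips every time the index wraps around the ring

    def advance(self) -> None:
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1

_state = _ToyPipelineState(num_stages=2)  # e.g. Q_stage = 2
for _ in range(5):                        # five m_blocks worth of loads
    _state.advance()
assert (_state.index, _state.phase) == (1, 0)
# ----------------------------------------------------------------------------------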
tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def apply_score_mod( + self, + acc_S: cute.Tensor, + thr_mma_SdP: cute.core.ThrMma, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen_info: SeqlenInfoQK, + aux_tensors=None, + fastdiv_mods=(None, None), + ): + # [NOTE] SdP_swapAB: swapAB transposes the tile, so use (n, m) indexing + cS = cute.make_identity_tensor( + (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n) + ) + cS = cute.domain_offset( + (n_block * self.tile_n, m_block * self.tile_m) + if self.SdP_swapAB + else (m_block * self.tile_m, n_block * self.tile_n), + cS, + ) + tScS = thr_mma_SdP.partition_C(cS) + + apply_score_mod_inner( + acc_S, + tScS, + self.score_mod, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead, + transpose_indices=self.SdP_swapAB, + ) + + @cute.jit + def apply_score_mod_bwd( + self, + grad_tensor: cute.Tensor, + score_tensor: cute.Tensor, + thr_mma_SdP: cute.core.ThrMma, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen_info: SeqlenInfoQK, + aux_tensors=None, + fastdiv_mods=(None, None), + ): + cS = cute.make_identity_tensor( + (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n) + ) + cS = cute.domain_offset( + (n_block * self.tile_n, m_block * self.tile_m) + if self.SdP_swapAB + else (m_block * self.tile_m, n_block * self.tile_n), + cS, + ) + tScS = thr_mma_SdP.partition_C(cS) + + apply_score_mod_bwd_inner( + grad_tensor, + score_tensor, + tScS, + self.score_mod_bwd, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead, + transpose_indices=self.SdP_swapAB, + ) + + @cute.jit + def mma( + self, + tiled_mma_SdP: cute.TiledMma, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + tiled_mma_dQ: cute.TiledMma, + mdK: cute.Tensor, + mdV: cute.Tensor, + mdQaccum: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sdO: cute.Tensor, + sP: Optional[cute.Tensor], + sdS: cute.Tensor, + sLSE: cute.Tensor, + sdPsum: cute.Tensor, + sdQaccum: cute.Tensor, + pipeline_Q: cutlass.pipeline.PipelineAsync, + pipeline_dO: cutlass.pipeline.PipelineAsync, + tidx: Int32, + tma_atom_dK: cute.CopyAtom, + tma_atom_dV: cute.CopyAtom, + r2s_tiled_copy_dQaccum: cute.TiledCopy, + r2s_tiled_copy_dKVaccum: cute.TiledCopy, + sdKVaccum_layout: cute.Layout, + softmax_scale_log2: Float32, + softmax_scale: Float32, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + blocksparse_tensors: Optional[BlockSparseTensors] = None, + qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, + ): + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + warp_group_thread_layout = cute.make_layout( + self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + ) + thr_mma_SdP = tiled_mma_SdP.get_slice(tidx) + wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_dQ 
= tiled_mma_dQ.get_slice(warp_group_thread_layout(warp_group_idx)) + # S = Q @ K.T + tSrQ, tSrK = mma_partition_fragment_AB(wg_mma_SdP, sQ, sK, self.SdP_swapAB) + # dP = dO @ V.T + tdPrdO, tdPrV = mma_partition_fragment_AB(wg_mma_SdP, sdO, sV, self.SdP_swapAB) + # dV += P.T @ dO + sPt = utils.transpose_view(sP) if sP is not None else None + sdOt = utils.transpose_view(sdO) + tdVrPt, tdVrdOt = mma_partition_fragment_AB(wg_mma_dV, sPt, sdOt, self.dKV_swapAB) + # dK += dS.T @ Q + sdSt = utils.transpose_view(sdS) + sQt = utils.transpose_view(sQ) + tdKrdSt, tdKrQt = mma_partition_fragment_AB(wg_mma_dK, sdSt, sQt, self.dKV_swapAB) + # dQ = dS @ K + sKt = utils.transpose_view(sK) + tdQrdS, tdQrKt = mma_partition_fragment_AB(wg_mma_dQ, sdS, sKt, self.dQ_swapAB) + + # Smem copy atom tiling + smem_copy_atom_PdS = utils.get_smem_store_atom( + self.arch, self.dtype, transpose=self.SdP_swapAB + ) + smem_thr_copy_PdS = cute.make_tiled_copy_C(smem_copy_atom_PdS, tiled_mma_SdP).get_slice( + tidx + ) + tPsP = None + if const_expr(sP is not None): + tPsP = smem_thr_copy_PdS.partition_D(sP if const_expr(not self.SdP_swapAB) else sPt) + tdSsdS = smem_thr_copy_PdS.partition_D(sdS if const_expr(not self.SdP_swapAB) else sdSt) + + sLSE_mma = cute.make_tensor( + sLSE.iterator, + cute.make_layout( + (self.tile_m, self.tile_n, self.Q_stage), + stride=(1, 0, cute.round_up(self.tile_m, 64)), + ), + ) + sdPsum_mma = cute.make_tensor( + sdPsum.iterator, + cute.make_layout( + (self.tile_m, self.tile_n, self.dO_stage), + stride=(1, 0, cute.round_up(self.tile_m, 64)), + ), + ) + if const_expr(self.SdP_swapAB): + sLSE_mma = utils.transpose_view(sLSE_mma) + sdPsum_mma = utils.transpose_view(sdPsum_mma) + LSEslice = (None, 0, None) if const_expr(not self.SdP_swapAB) else (0, None, None) + tLSEsLSE = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sLSE_mma))[LSEslice] + tLSEsdPsum = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sdPsum_mma))[LSEslice] + + smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx) + tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum) + + dV_shape = (self.tile_n, self.tile_hdimv) + acc_dV = cute.make_fragment( + tiled_mma_dV.partition_shape_C(dV_shape if not self.dKV_swapAB else dV_shape[::-1]), + Float32, + ) + dK_shape = (self.tile_n, self.tile_hdim) + acc_dK = cute.make_fragment( + tiled_mma_dK.partition_shape_C(dK_shape if not self.dKV_swapAB else dK_shape[::-1]), + Float32, + ) + + mma_qk_fn = partial( + gemm_zero_init, + tiled_mma_SdP, + (self.tile_m, self.tile_n), + tSrQ, + tSrK, + swap_AB=self.SdP_swapAB, + ) + mma_dov_fn = partial( + gemm_zero_init, + tiled_mma_SdP, + (self.tile_m, self.tile_n), + tdPrdO, + tdPrV, + swap_AB=self.SdP_swapAB, + ) + if const_expr(not self.mma_dkv_is_rs): + mma_pdo_fn = partial( + gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt, swap_AB=self.dKV_swapAB + ) + mma_dsq_fn = partial( + gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt, swap_AB=self.dKV_swapAB + ) + else: + assert not self.dKV_swapAB + mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tCrB=tdVrdOt) + mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tCrB=tdKrQt) + mma_dsk_fn = partial( + gemm_zero_init, + tiled_mma_dQ, + (self.tile_m, self.tile_hdim), + tdQrdS, + tdQrKt, + swap_AB=self.dQ_swapAB, + ) + + mma_one_m_block_all = partial( + self.mma_one_m_block, + warp_group_idx=warp_group_idx, + mma_qk_fn=mma_qk_fn, + mma_dov_fn=mma_dov_fn, + mma_pdo_fn=mma_pdo_fn, + mma_dsq_fn=mma_dsq_fn, + mma_dsk_fn=mma_dsk_fn, + pipeline_Q=pipeline_Q, + 
pipeline_dO=pipeline_dO, + tLSEsLSE=tLSEsLSE, + tLSEsdPsum=tLSEsdPsum, + tPsP=tPsP, + tdSsdS=tdSsdS, + tdQsdQaccum=tdQsdQaccum, + smem_thr_copy_PdS=smem_thr_copy_PdS, + smem_thr_copy_dQaccum=smem_thr_copy_dQaccum, + softmax_scale_log2=softmax_scale_log2, + # acc_dV=acc_dV, + # acc_dK=acc_dK, + ) + + consumer_state_Q = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage + ) + consumer_state_dO = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mask = AttentionMaskCls(seqlen) + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + + if const_expr(not self.use_block_sparsity): + process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + else: + total_m_block_cnt = get_total_q_block_count_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + process_tile = total_m_block_cnt > Int32(0) + + if process_tile: + if const_expr(not self.use_block_sparsity): + mask_fn = partial( + mask.apply_mask, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + thr_mma=thr_mma_SdP, + mask_seqlen=True, + mask_causal=self.is_causal, + mask_local=self.is_local, + mask_mod=self.mask_mod, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + dKV_accumulate = False + for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): + consumer_state_Q, consumer_state_dO = mma_one_m_block_all( + m_block, + consumer_state_Q, + consumer_state_dO, + mask_fn=mask_fn, + dKV_accumulate=dKV_accumulate, + thr_mma_SdP=thr_mma_SdP, + batch_idx=batch_idx, + head_idx=head_idx, + n_block=n_block, + softmax_scale=softmax_scale, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + dKV_accumulate = True + else: + consumer_state_Q, consumer_state_dO = consume_block_sparse_mma_bwd_sm90( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + consumer_state_Q, + consumer_state_dO, + mma_one_m_block_all, + mask, + self.mask_mod, + is_causal=self.is_causal, + is_local=self.is_local, + thr_mma_SdP=thr_mma_SdP, + softmax_scale=softmax_scale, + seqlen=seqlen, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + + if const_expr(self.qhead_per_kvhead == 1): + acc_dK.store(acc_dK.load() * softmax_scale) + self.epilogue_dKV( + acc_dV, + mdV, + sV, + acc_dK, + mdK, + sK, + seqlen, + tma_atom_dK, + tma_atom_dV, + tiled_mma_dK, + tiled_mma_dV, + r2s_tiled_copy_dKVaccum, + sdKVaccum_layout, + tidx, + n_block, + head_idx, + batch_idx, + qhead_per_kvhead_divmod, + ) + else: + # Block sparsity: KV tile with zero Q blocks produces no dK/dV; write zeros. 
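# [Editor's note, added for clarity: if block sparsity leaves this (n_block, head)
#  pair with zero Q blocks, the MMA loop above never executes, so acc_dK / acc_dV
#  are never zero-initialized (zero_init only happens on the first iteration) and
#  still hold stale register contents.  The GQA epilogue below reduce-adds these
#  fragments into the global f32 accumulators, so they must be cleared first.]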
+ if const_expr(self.use_block_sparsity): + acc_dK.fill(0.0) + acc_dV.fill(0.0) + self.epilogue_dKV( + acc_dV, + mdV, + sV, + acc_dK, + mdK, + sK, + seqlen, + tma_atom_dK, + tma_atom_dV, + tiled_mma_dK, + tiled_mma_dV, + r2s_tiled_copy_dKVaccum, + sdKVaccum_layout, + tidx, + n_block, + head_idx, + batch_idx, + qhead_per_kvhead_divmod, + ) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def mma_one_m_block( + self, + m_block: Int32, + consumer_state_Q: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + consumer_state_dO: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + warp_group_idx: Int32, + mma_qk_fn: Callable, + mma_dov_fn: Callable, + mma_pdo_fn: Callable, + mma_dsq_fn: Callable, + mma_dsk_fn: Callable, + pipeline_Q: cutlass.pipeline.PipelineAsync, + pipeline_dO: cutlass.pipeline.PipelineAsync, + tLSEsLSE: cute.Tensor, + tLSEsdPsum: cute.Tensor, + tPsP: Optional[cute.Tensor], + tdSsdS: Optional[cute.Tensor], + tdQsdQaccum: cute.Tensor, + smem_thr_copy_PdS: cute.TiledCopy, + smem_thr_copy_dQaccum: cute.TiledCopy, + softmax_scale_log2: Float32, + mask_fn: Optional[Callable] = None, + dKV_accumulate: Boolean = True, + thr_mma_SdP: Optional[cute.core.ThrMma] = None, + batch_idx: Int32 = 0, + head_idx: Int32 = 0, + n_block: Int32 = 0, + softmax_scale: Float32 = 1.0, + seqlen: Optional[SeqlenInfoQK] = None, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + ): + consumer_state_dO_cur = ( + consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q + ) + smem_idx_Q = consumer_state_Q.index + smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0 + smem_idx_PdS = smem_idx_Q if const_expr(self.PdS_stage > 1) else 0 + # (1) [GEMM 1] S = Q @ K^T + pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q)) + acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1) + tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q]) + # (2) [GEMM 2] dP = dO @ V.T + pipeline_dO.consumer_wait( + consumer_state_dO_cur, pipeline_dO.consumer_try_wait(consumer_state_dO_cur) + ) + acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1) + + if const_expr(self.score_mod_bwd is not None): + acc_S_pre = cute.make_fragment_like(acc_S) + cute.autovec_copy(acc_S, acc_S_pre) + + if const_expr(self.score_mod is not None): + self.apply_score_mod( + acc_S, + thr_mma_SdP, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen, + aux_tensors, + fastdiv_mods, + ) + + # (3) [Pointwise 1] P = exp(S - LSE) + if cutlass.const_expr(mask_fn is not None): + mask_fn(acc_S, m_block=m_block) + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.SdP_swapAB) + for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): + for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True): + acc_S_mn[r, c] = cute.math.exp2( + acc_S_mn[r, c] * softmax_scale_log2 - tLSErLSE[r], fastmath=True + ) + tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO]) + + # Convert P from f32 -> f16 + tdVrP = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_S), self.dtype) + # R2S for P + if const_expr(not self.mma_dkv_is_rs): + # sync to ensure P has already been used in the previous iteration before overwriting + if const_expr(self.PdS_stage == 1): + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) + tPrP = smem_thr_copy_PdS.retile(tdVrP) + cute.copy(smem_thr_copy_PdS, tPrP, tPsP[None, 
None, None, smem_idx_PdS]) + + # (4) [Pointwise 2] dS = P*(dP-dPsum) + warpgroup.wait_group(0) + acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP, transpose=self.SdP_swapAB) + for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): + for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True): + acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r]) + + if const_expr(self.score_mod_bwd is not None): + self.apply_score_mod_bwd( + acc_dP, + acc_S_pre, + thr_mma_SdP, + batch_idx, + head_idx, + m_block, + n_block, + softmax_scale, + seqlen, + aux_tensors, + fastdiv_mods, + ) + + # Convert dS from f32 -> f16 + tdKrdS = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_dP), self.dtype) + + # If there's double buffering on dS, we don't need to sync here. + # Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. + # But because both WGs have to sync at the end of the loop and double buffering, + # this race condition is not possible. + # This sync is to ensure (1) P is written in case of !mma_dkv_is_rs and + # (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs. + if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)): + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) + + # R2S for dS + tdSrdS = smem_thr_copy_PdS.retile(tdKrdS) + cute.copy(smem_thr_copy_PdS, tdSrdS, tdSsdS[None, None, None, smem_idx_PdS]) + + # (5) [GEMM 3] dV += P.T @ dO + if const_expr(not self.mma_dkv_is_rs): + mma_pdo_fn( + A_idx=smem_idx_PdS, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1 + ) + else: + mma_pdo_fn(tCrA=tdVrP, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1) + + # smem fence to make sure sdS is written before it's read by WGMMA + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads + ) + # (6) [GEMM 4] dQ = dS @ K + acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV) + pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done + + # (7) [GEMM 5] dK += dS.T @ Q + if const_expr(not self.mma_dkv_is_rs): + mma_dsq_fn( + A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1 + ) + else: + mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dQ) + + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)) + cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, + ) + + warpgroup.wait_group(0) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dK) + pipeline_Q.consumer_release(consumer_state_Q) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("tidx = {}, m_block = {}, after pipeline_Q consumer release", cute.arch.thread_idx()[0], m_block) + + 
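# ----------------------------------------------------------------------------------
# [Editor's note: illustrative sketch, not part of this patch.]  A plain NumPy
# reference for what one pass of the loop body above computes for a single
# (m_block, n_block) tile, with step numbers matching the comments in
# mma_one_m_block.  The forward quantities (LSE, dPsum) are recomputed here without
# max-subtraction for brevity, whereas the kernel reads them from global memory and
# folds the softmax scale into exp2 via softmax_scale_log2; the scale on dQ/dK is
# likewise applied later (epilogue / dQ postprocess) rather than inline.
import numpy as np

def _bwd_tile_reference(Q, K, V, dO, scale):
    S = scale * (Q @ K.T)                               # (1) S = Q @ K^T
    LSE = np.log(np.exp(S).sum(axis=-1, keepdims=True))
    P = np.exp(S - LSE)                                 # (3) P = exp(S - LSE)
    O = P @ V
    dPsum = (dO * O).sum(axis=-1, keepdims=True)        # equals rowsum(P * dP)
    dP = dO @ V.T                                       # (2) dP = dO @ V^T
    dS = P * (dP - dPsum)                               # (4) dS = P * (dP - dPsum)
    dV = P.T @ dO                                       # (5) dV += P^T @ dO
    dQ = scale * (dS @ K)                               # (6) dQ = dS @ K
    dK = scale * (dS.T @ Q)                             # (7) dK += dS^T @ Q
    return dQ, dK, dV
# ----------------------------------------------------------------------------------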
consumer_state_Q.advance() + consumer_state_dO.advance() + return consumer_state_Q, consumer_state_dO + + @cute.jit + def epilogue_dKV( + self, + acc_dV: cute.Tensor, + mdV: cute.Tensor, + sV: cute.Tensor, + acc_dK: cute.Tensor, + mdK: cute.Tensor, + sK: cute.Tensor, + seqlen: SeqlenInfoQK, + tma_atom_dK: cute.CopyAtom, + tma_atom_dV: cute.CopyAtom, + tiled_mma_dK: cute.TiledMma, + tiled_mma_dV: cute.TiledMma, + r2s_tiled_copy_dKVaccum: cute.TiledCopy, + sdKVaccum_layout: cute.Layout, + tidx: Int32, + n_block: Int32, + head_idx: Int32, + batch_idx: Int32, + qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + if const_expr(self.qhead_per_kvhead == 1): + rdV = cute.make_fragment_like(acc_dV, self.dtype) + rdV.store(acc_dV.load().to(self.dtype)) + rdK = utils.cvt_f16(acc_dK, self.dtype) + + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) + + smem_copy_atom_dKV = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=self.dKV_swapAB, num_matrices=4), + self.dtype, + ) + smem_thr_copy_dK = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dK).get_slice( + tidx + ) + smem_thr_copy_dV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma_dV).get_slice( + tidx + ) + mdV_cur = mdV[None, None, head_idx, batch_idx] + mdK_cur = mdK[None, None, head_idx, batch_idx] + gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) + gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) + store_dK, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dK, 0, cute.make_layout(1), sK, gdK, single_stage=True + ) + store_dV, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_dV, 0, cute.make_layout(1), sV, gdV, single_stage=True + ) + + taccdVrdV = smem_thr_copy_dV.retile(rdV) + sdV = sV if const_expr(not self.dKV_swapAB) else utils.transpose_view(sV) + taccdVsdV = smem_thr_copy_dV.partition_D(sdV) + cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) + if warp_idx == 4: + store_dV() + taccdKrdK = smem_thr_copy_dK.retile(rdK) + sdK = sK if const_expr(not self.dKV_swapAB) else utils.transpose_view(sK) + taccdKsdK = smem_thr_copy_dK.partition_D(sdK) + cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) + if warp_idx == 4: + store_dK() + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + else: + head_idx_kv = head_idx // qhead_per_kvhead_divmod + + mdKaccum_cur = mdK[None, head_idx_kv, batch_idx] + gdKaccum_ = cute.local_tile(mdKaccum_cur, (self.tile_n * self.tile_hdim,), (n_block,)) + gdKaccum = cute.flat_divide( + gdKaccum_, (self.tile_n * self.tile_hdim // self.num_mma_warp_groups,) + ) + + mdVaccum_cur = mdV[None, head_idx_kv, batch_idx] + gdVaccum_ = cute.local_tile(mdVaccum_cur, (self.tile_n * self.tile_hdimv,), (n_block,)) + gdVaccum = cute.flat_divide( + gdVaccum_, (self.tile_n * self.tile_hdimv // self.num_mma_warp_groups,) + ) + + sdKVaccum = cute.make_tensor( + cute.recast_ptr(sV.iterator, dtype=Float32), + sdKVaccum_layout, + ) + + smem_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_slice(tidx) + tdKsdKVaccum = 
smem_thr_copy_dKVaccum.partition_D(sdKVaccum) + + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) + + tdKrdKaccum_flat = cute.make_tensor( + acc_dK.iterator, cute.make_layout(tdKsdKVaccum.shape) + ) + cute.autovec_copy(tdKrdKaccum_flat, tdKsdKVaccum) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) + + if warp_idx == 4: + with cute.arch.elect_one(): + for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + copy_utils.cpasync_reduce_bulk_add_f32( + sdKVaccum[None, wg_idx].iterator, + gdKaccum[None, wg_idx].iterator, + self.tma_copy_bytes["dKacc"] // self.num_mma_warp_groups, + ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) + + tdVrdVaccum_flat = cute.make_tensor( + acc_dV.iterator, cute.make_layout(tdKsdKVaccum.shape) + ) + cute.autovec_copy(tdVrdVaccum_flat, tdKsdKVaccum) + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads + ) + + if warp_idx == 4: + with cute.arch.elect_one(): + for wg_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + copy_utils.cpasync_reduce_bulk_add_f32( + sdKVaccum[None, wg_idx].iterator, + gdVaccum[None, wg_idx].iterator, + self.tma_copy_bytes["dVacc"] // self.num_mma_warp_groups, + ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + + @cute.jit + def dQaccum_store( + self, + mdQaccum: cute.Tensor, + sdQaccum: cute.Tensor, + block_info: BlockInfo, + TileSchedulerCls: cutlass.Constexpr[Callable], + SeqlenInfoCls: cutlass.Constexpr[Callable], + blocksparse_tensors: Optional[BlockSparseTensors] = None, + ): + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + n_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] + gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) + # (M * K / WG, WG, _) + gdQaccum = cute.flat_divide( + gdQaccum_, (self.tile_m * self.tile_hdim // self.num_mma_warp_groups,) + ) + m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) + if const_expr(not self.use_block_sparsity): + process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + loop_count = m_block_max - m_block_min + else: + total_block_cnt = get_total_q_block_count_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + process_tile = total_block_cnt > Int32(0) + + if process_tile: + if const_expr(not self.use_block_sparsity): + for iter_idx in cutlass.range(loop_count, unroll=1): + m_block = m_block_min + iter_idx + m_block_safe = m_block + + for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + cute.arch.barrier( + barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + + cute.arch.WARP_SIZE, + ) + with cute.arch.elect_one(): + copy_utils.cpasync_reduce_bulk_add_f32( + sdQaccum[None, warp_group_idx].iterator, + gdQaccum[None, warp_group_idx, m_block_safe].iterator, + 
self.tma_copy_bytes["dQ"], + ) + cute.arch.cp_async_bulk_commit_group() + for warp_group_idx in cutlass.range_constexpr(self.num_mma_warp_groups): + cute.arch.cp_async_bulk_wait_group( + self.num_mma_warp_groups - 1 - warp_group_idx, read=True + ) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, + number_of_threads=self.num_threads_per_warp_group + + cute.arch.WARP_SIZE, + ) + else: + dQaccum_store_block_sparse_bwd_sm90( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + sdQaccum, + gdQaccum, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + num_mma_warp_groups=self.num_mma_warp_groups, + num_threads_per_warp_group=self.num_threads_per_warp_group, + tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"], + ) + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py new file mode 100644 index 00000000000..c13cd267719 --- /dev/null +++ b/flash_attn/cute/flash_fwd.py @@ -0,0 +1,2484 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of +# https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm80.h +# and https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm90.h +# from Cutlass C++ to Cute-DSL. +# Built on Cute-DSL example: https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py + +import math +from types import SimpleNamespace +from typing import Type, Callable, Optional, List +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Constexpr, Float32, Int32, const_expr, Boolean +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass.cute.arch import ProxyKind, SharedSpace +import cutlass.utils as utils_basic +from cutlass.utils import LayoutEnum +import cutlass.utils.hopper_helpers as sm90_utils_basic + +from quack import copy_utils as quack_copy_utils + +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute import hopper_helpers as sm90_utils +from flash_attn.cute import utils +from flash_attn.cute import copy_utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + produce_block_sparse_loads, + consume_block_sparse_loads, +) +from flash_attn.cute import pipeline +from flash_attn.cute.pack_gqa import PackGQA +from flash_attn.cute.named_barrier import NamedBarrierFwd +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + ParamsBase, +) +from cutlass.cute import FastDivmodDivisor + + +class FlashAttentionForwardBase: + arch: int = 80 + + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: int = 1, + is_causal: bool = False, + is_local: bool = False, + pack_gqa: bool = True, + tile_m: int = 128, + tile_n: int = 128, + num_stages: int = 1, + num_threads: int = 128, + Q_in_regs: bool = False, + score_mod: Optional[cutlass.Constexpr] = None, + mask_mod: Optional[cutlass.Constexpr] = None, + 
has_aux_tensors: bool = False, + ): + """Initializes the configuration for a flash attention kernel. + + All contiguous dimensions must be at least 16 bytes aligned, which means that the head dimension + should be a multiple of 8. + + :param head_dim: head dimension + :type head_dim: int + :param tile_m: m block size + :type tile_m: int + :param tile_n: n block size + :type tile_n: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + :param score_mod: A callable that takes the attention scores and applies a modification. + Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Any`` + :param mask_mod: A callable that takes the attention scores and returns a boolean representing whether that score should be masked. + Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Boolean`` + """ + self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + # Can save registers (and hence be faster) if we don't have to check hdim predication + self.check_hdim_oob = head_dim != self.tile_hdim + self.check_hdim_v_oob = head_dim_v != self.tile_hdimv + self.qhead_per_kvhead = qhead_per_kvhead + self.is_causal = is_causal + self.is_local = is_local + self.pack_gqa = pack_gqa + self.tile_m = tile_m + self.tile_n = tile_n + self.num_threads = num_threads + self.num_stages = num_stages + self.Q_in_regs = Q_in_regs + self.score_mod = score_mod + self.mask_mod = mask_mod + self.qk_acc_dtype = Float32 + if const_expr(has_aux_tensors): + self.vec_size: cutlass.Constexpr = 1 + else: + self.vec_size: cutlass.Constexpr = 2 + + @staticmethod + def can_implement( + dtype, + head_dim, + head_dim_v, + tile_m, + tile_n, + num_stages, + num_threads, + is_causal, + Q_in_regs=False, + ) -> bool: + """Check if the kernel can be implemented with the given parameters. 
+ + :param dtype: data type + :type dtype: cutlass.Numeric + :param head_dim: head dimension + :type head_dim: int + :param tile_m: m block size + :type tile_m: int + :param tile_n: n block size + :type tile_n: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + :type is_causal: bool + + :return: True if the kernel can be implemented, False otherwise + :rtype: bool + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if head_dim_v % 8 != 0: + return False + if tile_n % 16 != 0: + return False + if num_threads % 32 != 0: + return False + # Check if block size setting is out of shared memory capacity + # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size + smem_usage_Q = tile_m * head_dim * 2 + smem_usage_K = tile_n * head_dim * num_stages * 2 + smem_usage_V = tile_n * head_dim_v * num_stages * 2 + smem_usage_QV = ( + (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) + ) + smem_usage = smem_usage_QV + smem_usage_K + # TODO: sm86 and sm89 + smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") + if smem_usage > smem_capacity: + return False + # Check if twice the block size is divisible by the number of threads + if (tile_m * 2) % num_threads != 0: + return False + return True + + def _check_type( + self, + mQ_type: Type[cutlass.Numeric], + mK_type: Type[cutlass.Numeric], + mV_type: Type[cutlass.Numeric], + mO_type: Type[cutlass.Numeric], + mLSE_type: Type[cutlass.Numeric] | None, + mCuSeqlensQ_type: Type[cutlass.Numeric] | None, + mCuSeqlensK_type: Type[cutlass.Numeric] | None, + mSeqUsedQ_type: Type[cutlass.Numeric] | None, + mSeqUsedK_type: Type[cutlass.Numeric] | None, + ): + # Get the data type and check if it is fp16 or bf16 + if const_expr(not (mQ_type == mK_type == mV_type == mO_type)): + raise TypeError("All tensors must have the same data type") + if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): + raise TypeError("Only Float16 or BFloat16 is supported") + if const_expr(mLSE_type not in [None, Float32]): + raise TypeError("LSE tensor must be Float32") + if const_expr(mCuSeqlensQ_type not in [None, Int32]): + raise TypeError("cu_seqlens_q tensor must be Int32") + if const_expr(mCuSeqlensK_type not in [None, Int32]): + raise TypeError("cu_seqlens_k tensor must be Int32") + if const_expr(mSeqUsedQ_type not in [None, Int32]): + raise TypeError("seqused_q tensor must be Int32") + if const_expr(mSeqUsedK_type not in [None, Int32]): + raise TypeError("seqused_k tensor must be Int32") + assert mQ_type == self.dtype + + def _setup_attributes(self): + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: Q/K/V + # /////////////////////////////////////////////////////////////////////////////// + sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = ( + self._get_smem_layout_atom() + ) + self.sQ_layout = cute.tile_to_shape( + sQ_layout_atom, + (self.tile_m, self.tile_hdim), + (0, 1), + ) + self.sK_layout = cute.tile_to_shape( + sK_layout_atom, + (self.tile_n, self.tile_hdim, self.num_stages), + (0, 1, 2), + ) + self.sV_layout = cute.tile_to_shape( + sV_layout_atom, + (self.tile_n, self.tile_hdimv, self.num_stages), + (0, 1, 2), + ) + self.sO_layout = cute.tile_to_shape( + sO_layout_atom, + (self.tile_m, self.tile_hdimv), + (0, 1), + ) + if const_expr(sP_layout_atom is not None): + self.sP_layout = 
cute.tile_to_shape( + sP_layout_atom, + (self.tile_m, self.tile_n), + (0, 1), + ) + else: + self.sP_layout = None + + # /////////////////////////////////////////////////////////////////////////////// + # GMEM Tiled copy: + # /////////////////////////////////////////////////////////////////////////////// + # Thread layouts for copies + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype.width + # atom_async_copy: async copy atom for QKV load + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + # atom_universal_copy: universal copy atom for O store + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + # tQ_layout and tK_layout: thread layout for QK load + tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems + assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, ( + "num_threads must be divisible by tQK_shape_dim_1" + ) + assert self.num_producer_threads % tQK_shape_dim_1 == 0, ( + "num_threads must be divisible by tQK_shape_dim_1" + ) + tQ_layout = cute.make_ordered_layout( + (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1), + order=(1, 0), + ) + tK_layout = cute.make_ordered_layout( + (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1), + order=(1, 0), + ) + # So that we don't have to check if we overshoot kBlockM when we load Q + assert self.tile_m % tQ_layout.shape[0] == 0 + tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems + tV_layout = cute.make_ordered_layout( + (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1), + order=(1, 0), + ) + # TODO: need a different layout for O if O dtype is not the same as V dtype + # tO_layout: thread layout for O store + tO_layout = cute.make_ordered_layout( + (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1), + order=(1, 0), + ) + # So that we don't have to check if we overshoot kBlockM when we store O + assert self.tile_m % tO_layout.shape[0] == 0 + + # Value layouts for copies + vQKV_layout = cute.make_layout((1, async_copy_elems)) + vO_layout = vQKV_layout + + self.gmem_tiled_copy_Q = cute.make_tiled_copy_tv(atom_async_copy, tQ_layout, vQKV_layout) + self.gmem_tiled_copy_K = cute.make_tiled_copy_tv(atom_async_copy, tK_layout, vQKV_layout) + self.gmem_tiled_copy_V = cute.make_tiled_copy_tv(atom_async_copy, tV_layout, vQKV_layout) + # gmem_tiled_copy_O: tiled copy for O store + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) + + def _get_smem_layout_atom(self): + raise NotImplementedError() + + def _get_tiled_mma(self): + raise NotImplementedError() + + def _get_shared_storage_cls(self): + raise NotImplementedError() + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + stream: cuda.CUstream, + ): + """Configures and launches the flash attention kernel. 
+ + mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: + (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) + """ + raise NotImplementedError() + + @cute.jit + def epilogue( + self, + acc_O: cute.Tensor, + lse: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + sO: cute.Tensor, + seqlen: SeqlenInfoQK, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + tiled_mma: cute.TiledMma, + tidx: Int32, + m_block: Int32, + head_idx: Int32, + batch_idx: Int32, + ): + # store acc_O + rO = cute.make_fragment_like(acc_O, self.dtype) + rO.store(acc_O.load().to(self.dtype)) + # Make sure all threads have finished reading V + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + ) + smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype) + smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) + taccOrO = smem_thr_copy_O.retile(rO) + taccOsO = smem_thr_copy_O.partition_D(sO) + # taccOsO = quack_copy_utils.partition_D_position_independent(smem_thr_copy_O, sO) + # copy acc O from rmem to smem with the smem copy atom + cute.copy(smem_copy_atom_O, taccOrO, taccOsO) + + cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv)) + pack_gqa = PackGQA( + self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead + ) + + # Write LSE from rmem -> gmem + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[None, head_idx, batch_idx] + else: + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) + if const_expr(not self.pack_gqa): + gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,)) + gLSE_expanded_layout = cute.append( + gLSE.layout, cute.make_layout((self.tile_hdimv,), stride=(0,)) + ) + gLSE_expanded = cute.make_tensor(gLSE.iterator, gLSE_expanded_layout) + thr_mma = tiled_mma.get_slice(tidx) + taccOgLSE = utils.make_acc_tensor_mn_view(thr_mma.partition_C(gLSE_expanded)) + assert cute.size(taccOgLSE, mode=[0]) == cute.size(lse) + taccOcO = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cO)) + t0accOcO = utils.make_acc_tensor_mn_view(thr_mma.get_slice(0).partition_C(cO)) + # Only the thread corresponding to column 0 writes out the lse to gmem + if taccOcO[0][1] == 0: + for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])): + if ( + t0accOcO[m, 0][0] + < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0] + ): + taccOgLSE[m, 0] = lse[m] + else: + pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) + + if const_expr(not seqlen.has_cu_seqlens_q): + mO_cur = mO[None, None, head_idx, batch_idx] + else: + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx]) + # thr_mma = tiled_mma.get_slice(tidx) + # taccOgO = thr_mma.partition_C(gO) + # cute.autovec_copy(rO, taccOgO) + # sync to make sure all smem stores are done + if const_expr(self.use_tma_O): + # ensure smem writes are visible to TMA + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, + ) + gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, 
cute.make_layout(1), sO, gO, single_stage=True + ) + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 4: + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, + ) + store_O() + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + else: + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads, + ) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + tOrO = cute.make_fragment_like(tOsO, self.dtype) + # load acc O from smem to rmem for wider vectorization + cute.autovec_copy(tOsO, tOrO) + if const_expr(not self.pack_gqa): + gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + # copy acc O from rmem to gmem + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if ( + t0OcO[0, rest_m, 0][0] + < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0] + ): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None], + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, + ) + else: + pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q) + + @cute.jit + def advance_pipeline(self, pipeline_index): + return pipeline_index + 1 if pipeline_index < self.num_stages - 1 else 0 + + @cute.jit + def load_Q( + self, + gmem_thr_copy: cute.TiledCopy, + gQ: cute.Tensor, + sQ: cute.Tensor, + block: Int32, + seqlen: Int32, + headdim: Int32, + ): + tQsQ, tQgQ = gmem_thr_copy.partition_D(sQ), gmem_thr_copy.partition_S(gQ) + cQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) + tQcQ = gmem_thr_copy.partition_S(cQ) + t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) + tQpQ = utils.predicate_k(tQcQ, limit=headdim) + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + # Instead of using tQcQ, we using t0QcQ and subtract the offset from the limit + # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time. + if t0QcQ[0, m, 0][0] < seqlen - block * self.tile_m - tQcQ[0][0]: + cute.copy( + gmem_thr_copy, + tQgQ[None, m, None], + tQsQ[None, m, None], + pred=tQpQ[None, m, None] if const_expr(self.check_hdim_oob) else None, + ) + # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + + @cute.jit + def load_K( + self, + gmem_tiled_copy: cute.TiledCopy, + tKgK: cute.Tensor, + tKsK: cute.Tensor, + tKcK: cute.Tensor, + t0KcK: cute.Tensor, + tKpK: cute.Tensor, + block: Int32, + smem_pipe_write: Int32, + seqlen: Int32, + need_predicates: cutlass.Constexpr, + ): + # Do we need to check if we overshoot kBlockN when we load K? + is_even_n_smem_k = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0 + if const_expr(need_predicates or not is_even_n_smem_k): + # Instead of using tKcK, we using t0KcK and subtract the offset from the limit + # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time. 
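+            # Concretely, for a given thread the row coordinate satisfies
+            # tKcK[0, n, 0][0] == t0KcK[0, n, 0][0] + tKcK[0][0], so the per-thread check
+            # tKcK[0, n, 0][0] < seqlen - block * tile_n is rewritten as
+            # t0KcK[0, n, 0][0] < (seqlen - block * tile_n) - tKcK[0][0]: the left-hand side
+            # becomes a compile-time constant and the per-thread offset is folded into
+            # seqlen_limit, which is computed once before the n loop below.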
+ if const_expr(is_even_n_smem_k): + seqlen_limit = seqlen - block * self.tile_n + else: + if const_expr(not need_predicates): + seqlen_limit = self.tile_n + else: + seqlen_limit = cutlass.min(seqlen - block * self.tile_n, self.tile_n) + seqlen_limit -= tKcK[0][0] + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): + if t0KcK[0, n, 0][0] < seqlen_limit: + cute.copy( + gmem_tiled_copy, + tKgK[None, n, None, block], + tKsK[ + None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0 + ], + pred=tKpK[None, n, None] if const_expr(self.check_hdim_oob) else None, + ) + # We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + else: + cute.copy( + gmem_tiled_copy, + tKgK[None, None, None, block], + tKsK[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=tKpK if const_expr(self.check_hdim_oob) else None, + ) + + @cute.jit + def load_V( + self, + gmem_tiled_copy: cute.TiledCopy, + tVgV: cute.Tensor, + tVsV: cute.Tensor, + tVcV: cute.Tensor, + t0VcV: cute.Tensor, + tVpV: cute.Tensor, + block: Int32, + smem_pipe_write: Int32, + seqlen: Int32, + need_predicates: cutlass.Constexpr, + ): + # Do we need to check if we overshoot kBlockN when we load V? + is_even_n_smem_v = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0 + if const_expr(need_predicates or not is_even_n_smem_v): + for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): + # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked + if ( + is_even_n_smem_v + or n < cute.size(tVsV.shape[1]) - 1 + or tVcV[0, n, 0][0] < self.tile_n + ): + predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None + if const_expr(need_predicates): + seqlen_limit = seqlen - block * self.tile_n - tVcV[0][0] + predicate_n = t0VcV[0, n, 0][0] < seqlen_limit + predicate = cute.make_fragment_like(tVpV[None, 0, None]) + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = ( + tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True + ) and predicate_n + cute.copy( + gmem_tiled_copy, + tVgV[None, n, None, block], + tVsV[ + None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0 + ], + pred=predicate, + ) + else: + cute.copy( + gmem_tiled_copy, + tVgV[None, None, None, block], + tVsV[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=tVpV if const_expr(self.check_hdim_v_oob) else None, + ) + + +class FlashAttentionForwardSm80(FlashAttentionForwardBase): + def _get_smem_layout_atom(self): + sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdim) + sK_layout_atom = sQ_layout_atom + sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdimv) + sO_layout_atom = sV_layout_atom + sP_layout_atom = None + return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom + + def _get_tiled_mma(self): + tiled_mma_qk = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), + (self.num_threads // 32, 1, 1), + permutation_mnk=(self.num_threads // 32 * 16, 16, 16), + ) + tiled_mma_pv = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), + (self.num_threads // 32, 1, 1), + permutation_mnk=(self.num_threads // 32 * 16, 16, 16), + ) + return tiled_mma_qk, tiled_mma_pv + + def _get_shared_storage_cls(self): + sQ_struct, sK_struct, sV_struct = [ + 
cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024] + for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) + ] + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] + + @cute.struct + class SharedStorageQKV: + sV: sV_struct + sQ: sQ_struct + sK: sK_struct + + @cute.struct + class SharedStorageSharedQV: + sQ: sQV_struct + sK: sK_struct + + return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + stream: cuda.CUstream, + softmax_scale: Optional[Float32] = None, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + learnable_sink: Optional[cute.Tensor] = None, + aux_tensors=None, + ): + """Configures and launches the flash attention kernel. + + mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: + (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) + """ + assert learnable_sink is None, "Learnable sink is not supported in this kernel" + self._check_type( + *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE)) + ) + tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() + self.num_mma_threads = tiled_mma_pv.size + self.num_producer_threads = self.num_threads + self.num_Q_load_threads = self.num_threads + self.num_epilogue_threads = self.num_threads + # self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None + self.use_tma_O = self.arch >= 90 + self._setup_attributes() + SharedStorage = self._get_shared_storage_cls() + # Assume all strides are divisible by 128 bits except the last stride + # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints) + new_stride = lambda t: ( + *( + cute.assume(s, divby=128 // t.element_type.width) + if not isinstance(s, int) or s != 0 + else s + for s in t.stride[:-1] + ), + t.stride[-1], + ) + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mQ, mK, mV, mO) + ] + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0])) + for t in (mQ, mK, mV, mO) + ] + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0])) + # grid_dim: (m_block, num_head, batch_size) + grid_dim = ( + cute.ceil_div(mQ.shape[0], self.tile_m), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]), + ) + LOG2_E = math.log2(math.e) + if const_expr(self.score_mod is None): + softmax_scale_log2 = Float32(softmax_scale * LOG2_E) + softmax_scale = None + else: + # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk + # But in the original base 10. 
We hijack softmax_scale_log2 to just be the change of base + # and correctly apply the softmax_scale prior to score_mod in the softmax step + softmax_scale_log2 = Float32(LOG2_E) + softmax_scale = Float32(softmax_scale) + + fastdiv_mods = None + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) + seqlen_k = cute.size(mK.shape[0]) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + + self.kernel( + mQ, + mK, + mV, + mO, + mLSE, + softmax_scale_log2, + softmax_scale, + window_size_left, + window_size_right, + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sO_layout, + self.sP_layout, + self.gmem_tiled_copy_Q, + self.gmem_tiled_copy_K, + self.gmem_tiled_copy_V, + self.gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, + SharedStorage, + aux_tensors, + fastdiv_mods, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + sP_layout: cute.ComposedLayout | None, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_K: cute.TiledCopy, + gmem_tiled_copy_V: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + SharedStorage: cutlass.Constexpr, + aux_tensors=None, + fastdiv_mods=None, + ): + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + m_block, num_head, batch_size = cute.arch.block_idx() + + block_info = BlockInfo( + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + False, # is_split_kv + window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + seqlen = SeqlenInfoQK.create(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0]) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # TODO: return early if n_block_max == 0 + # if self.is_causal: + # if n_block_max <= 0: + # return + n_block = n_block_max - 1 + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. 
+ # /////////////////////////////////////////////////////////////////////////////// + blkQ_shape = (self.tile_m, self.tile_hdim) + blkK_shape = (self.tile_n, self.tile_hdim) + blkV_shape = (self.tile_n, self.tile_hdimv) + gQ = cute.local_tile(mQ[None, None, num_head, batch_size], blkQ_shape, (m_block, 0)) + num_head_kv = num_head // self.qhead_per_kvhead + gK = cute.local_tile(mK[None, None, num_head_kv, batch_size], blkK_shape, (None, 0)) + gV = cute.local_tile(mV[None, None, num_head_kv, batch_size], blkV_shape, (None, 0)) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ = storage.sQ.get_tensor(sQ_layout) + sK = storage.sK.get_tensor(sK_layout) + if const_expr(not self.Q_in_regs): + sV = storage.sV.get_tensor(sV_layout) + else: + sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) + # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma + sVt = utils.transpose_view(sV) + + gmem_thr_copy_K = gmem_tiled_copy_K.get_slice(tidx) + gmem_thr_copy_V = gmem_tiled_copy_V.get_slice(tidx) + # (CPY_Atom, CPY_N, CPY_K, n_block) + tKsK, tKgK = gmem_thr_copy_K.partition_D(sK), gmem_thr_copy_K.partition_S(gK) + # (CPY_Atom, CPY_N, CPY_K, n_block) + tVsV, tVgV = gmem_thr_copy_V.partition_D(sV), gmem_thr_copy_V.partition_S(gV) + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma_qk = tiled_mma_qk.get_slice(tidx) + thr_mma_pv = tiled_mma_pv.get_slice(tidx) + tSrQ = thr_mma_qk.make_fragment_A(thr_mma_qk.partition_A(sQ)) + tSrK = thr_mma_qk.make_fragment_B(thr_mma_qk.partition_B(sK[None, None, 0])) + tOrVt = thr_mma_pv.make_fragment_B(thr_mma_pv.partition_B(sVt[None, None, 0])) + acc_shape_O = thr_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv)) + acc_O = cute.make_fragment(acc_shape_O, Float32) + acc_O.fill(0.0) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_QK = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + self.dtype, + ) + smem_copy_atom_V = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), + self.dtype, + ) + smem_thr_copy_Q = utils.make_tiled_copy_A(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) + smem_thr_copy_K = utils.make_tiled_copy_B(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) + smem_thr_copy_V = utils.make_tiled_copy_B(smem_copy_atom_V, tiled_mma_pv).get_slice(tidx) + + tSsQ = smem_thr_copy_Q.partition_S(sQ) + tSsK = smem_thr_copy_K.partition_S(sK) + tOsVt = smem_thr_copy_V.partition_S(sVt) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for KV + cK = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) + tKcK = gmem_thr_copy_K.partition_S(cK) + t0KcK = gmem_thr_copy_K.get_slice(0).partition_S(cK) + if const_expr(self.tile_hdim == 
self.tile_hdimv): + tVcV = tKcK + t0VcV = t0KcK + else: + cV = cute.make_identity_tensor((self.tile_n, self.tile_hdimv)) + tVcV = gmem_thr_copy_V.partition_S(cV) + t0VcV = gmem_thr_copy_V.get_slice(0).partition_S(cV) + # Allocate predicate tensors for m and n, here we only allocate the tile of k, and + # use "if" on the mn dimension. + # This is to reduce register pressure and gets 2-3% performance gain. + tKpK = utils.predicate_k(tKcK, limit=mK.shape[1]) + if const_expr(self.same_hdim_kv): + tVpV = tKpK + else: + tVpV = utils.predicate_k(tVcV, limit=mV.shape[1]) + + # shape: (atom_v_m * rest_m) + softmax = Softmax.create( + softmax_scale_log2, + num_rows=acc_O.shape[0][0] * acc_O.shape[1], + softmax_scale=softmax_scale, + ) + softmax.reset() + + # group parameters for compute_one_n_block + mma_params = SimpleNamespace( + thr_mma_qk=thr_mma_qk, + thr_mma_pv=thr_mma_pv, + tSrQ=tSrQ, + tSrK=tSrK, + tOrVt=tOrVt, + acc_O=acc_O, + ) + smem_copy_params = SimpleNamespace( + smem_thr_copy_Q=smem_thr_copy_Q, + smem_thr_copy_K=smem_thr_copy_K, + smem_thr_copy_V=smem_thr_copy_V, + tSsQ=tSsQ, + tSsK=tSsK, + tOsVt=tOsVt, + ) + load_K = partial( + self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, seqlen=seqlen.seqlen_k + ) + load_V = partial( + self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, seqlen=seqlen.seqlen_k + ) + + compute_one_n_block = partial( + self.compute_one_n_block, + mma_params=mma_params, + smem_copy_params=smem_copy_params, + softmax=softmax, + load_K=load_K, + load_V=load_V, + score_mod=self.score_mod, + batch_idx=batch_size, + head_idx=num_head, + m_block=m_block, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Start async loads of the last mn-tile, where we take care of the mn residue + gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) + self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, headdim=mQ.shape[1]) + cute.arch.cp_async_commit_group() + + def preprocess_Q(): + cute.arch.cp_async_wait_group(self.num_stages * 2 - 1) + if const_expr(self.Q_in_regs): + cute.arch.barrier() + tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ) + cute.copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view) + + # If Q_in_regs, we load Q, then load 1 stage of K, then (optionally) rotate Q and + # read from smem_q to registers, then load V. + # If !Q_in_regs, we load Q, load all stages of K & V, then (optionally) rotate Q. + if const_expr(self.Q_in_regs): + load_K(n_block, smem_pipe_write=0, need_predicates=True) + cute.arch.cp_async_commit_group() + preprocess_Q() + cute.arch.barrier() # Make sure all threads have read smem_q before loading V + + for stage in cutlass.range_constexpr(self.num_stages): + if const_expr(not self.Q_in_regs or stage > 0): + if stage == 0 or n_block - stage >= 0: + load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0) + cute.arch.cp_async_commit_group() + if const_expr(stage < self.num_stages - 1): + if stage == 0 or n_block - stage >= 0: + load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0) + cute.arch.cp_async_commit_group() + if const_expr(not self.Q_in_regs): + preprocess_Q() + + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # /////////////////////////////////////////////////////////////////////////////// + # Start processing of the first n-block. 
+ # For performance reason, we separate out two kinds of iterations: + # those that need masking on S, and those that don't. + # We need masking on S for the very last block when K and V has length not multiple of tile_n. + # We also need masking on S if it's causal, for the last several blocks. + mask = AttentionMask( + self.tile_m, + self.tile_n, + seqlen.seqlen_q, + seqlen.seqlen_k, + window_size_left, + window_size_right, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + mask_fn = partial( + mask.apply_mask, + m_block=m_block, + thr_mma=thr_mma_qk, + mask_causal=self.is_causal, + mask_local=self.is_local, + fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None, + ) + + # First iteration with seqlen masking + smem_pipe_read = Int32(0) + smem_pipe_write = Int32(self.num_stages - 1) + compute_one_n_block( + n_block, + smem_pipe_read, + smem_pipe_write, + is_first_n_block=True, + check_inf=True, + mask_fn=partial(mask_fn, mask_seqlen=True), + ) + smem_pipe_read = self.advance_pipeline(smem_pipe_read) + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 2 - n_tile + compute_one_n_block( + n_block, + smem_pipe_read, + smem_pipe_write, + check_inf=True, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + smem_pipe_read = self.advance_pipeline(smem_pipe_read) + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + # The remaining iterations have no masking + for n_tile in cutlass.range(n_block, unroll=1): + compute_one_n_block( + n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True + ) + smem_pipe_read = self.advance_pipeline(smem_pipe_read) + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + # TODO: local + + # normalize acc_O by row_sum and calculate the lse + row_scale = softmax.finalize() + softmax.rescale_O(acc_O, row_scale) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + # reuse sQ's data iterator + sO = cute.make_tensor(sQ.iterator, sO_layout) + self.epilogue( + acc_O, + softmax.row_sum, + mO, + mLSE, + sO, + seqlen, + gmem_tiled_copy_O, + None, + tiled_mma_pv, + tidx, + m_block, + num_head, + batch_size, + ) + + @cute.jit + def compute_one_n_block( + self, + n_block: Int32, + smem_pipe_read: Int32, + smem_pipe_write: Int32, + mma_params: SimpleNamespace, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + load_K: Callable, + load_V: Callable, + score_mod: Callable | None, + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + seqlen: SeqlenInfoQK, + aux_tensors=None, + fastdiv_mods=None, + mask_fn: Optional[Callable] = None, + is_first_n_block: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, + ): + """Compute one n_block of S/O. + + This function provides different variants for processing the first n block versus + subsequent blocks. 
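+
+        At a high level, each call waits for the K tile to arrive in shared memory,
+        kicks off the asynchronous load of the next V tile, computes S = Q @ K^T,
+        applies the optional score_mod and mask, performs the online-softmax update and
+        rescales the O accumulator, converts the probabilities to the input dtype, and
+        finally accumulates O += P @ V while the next K tile is being loaded.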
+ """ + + def sync(): + cute.arch.cp_async_wait_group(self.num_stages * 2 - 2) + cute.arch.barrier() + + acc_shape_S = mma_params.thr_mma_qk.partition_shape_C((self.tile_m, self.tile_n)) + acc_S = cute.make_fragment(acc_shape_S, Float32) + acc_S.fill(0.0) + # wait for smem tile QK before mma calculation for S + sync() + + # need predicates for the first tile + def load_V_next(): + if self.num_stages == 1 or n_block - self.num_stages + 1 >= 0: + load_V( + n_block - self.num_stages + 1, + smem_pipe_write, + need_predicates=is_first_n_block and self.num_stages == 1, + ) + cute.arch.cp_async_commit_group() + + load_V_next() + sm80_utils.gemm( + mma_params.thr_mma_qk, + acc_S, + mma_params.tSrQ, + mma_params.tSrK, + smem_copy_params.tSsQ, + smem_copy_params.tSsK[ + None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0 + ], + smem_copy_params.smem_thr_copy_Q, + smem_copy_params.smem_thr_copy_K, + # hook_fn=load_V_next, + A_in_regs=self.Q_in_regs, + ) + if const_expr(score_mod is not None): + self.apply_score_mod( + mma_params.thr_mma_qk, + batch_idx, + head_idx, + m_block, + acc_S, + n_block, + seqlen, + softmax_scale=softmax.softmax_scale, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + + def load_K_next(): + if n_block - self.num_stages >= 0: + load_K(n_block - self.num_stages, smem_pipe_write, need_predicates=False) + cute.arch.cp_async_commit_group() + + # wait for smem tile V for O + if const_expr(self.num_stages == 1): + sync() + load_K_next() + if const_expr(mask_fn is not None): + mask_fn(acc_S, n_block=n_block) + row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) + softmax.rescale_O(mma_params.acc_O, row_scale) + rP = cute.make_fragment_like(acc_S, self.dtype) + rP.store(acc_S.load().to(self.dtype)) + tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) + if const_expr(self.num_stages > 1): + sync() + load_K_next() + sm80_utils.gemm_rs( + mma_params.thr_mma_pv, + mma_params.acc_O, + tOrP, + mma_params.tOrVt, + smem_copy_params.tOsVt[ + None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0 + ], + smem_copy_params.smem_thr_copy_V, + # hook_fn=load_K_next, + ) + # if const_expr(self.num_stages > 1): + # load_K_next() + + +class FlashAttentionForwardSm90(FlashAttentionForwardBase): + arch = 90 + + def __init__( + self, + *args, + intra_wg_overlap: bool = True, + mma_pv_is_rs: bool = True, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.intra_wg_overlap = intra_wg_overlap + self.mma_pv_is_rs = mma_pv_is_rs + self.buffer_align_bytes = 1024 + + def _get_smem_layout_atom(self): + sQ_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim), + self.dtype, + ) + sK_layout_atom = sQ_layout_atom + sV_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv + ), + self.dtype, + ) + sO_layout_atom = sV_layout_atom + if not self.mma_pv_is_rs: + sP_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n + ), + self.dtype, + ) + else: + sP_layout_atom = None + return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom + + def _get_tiled_mma(self): + tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + 
warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.K, + Float32, + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_n), + ) + tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + Float32, + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_hdimv), + a_source=warpgroup.OperandSource.RMEM + if self.mma_pv_is_rs + else warpgroup.OperandSource.SMEM, + ) + tiled_mma_pv_rs = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + Float32, + atom_layout_mnk=(self.tile_m // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.tile_hdimv), + a_source=warpgroup.OperandSource.RMEM, + ) + return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs + + def _get_shared_storage_cls(self): + # If we use cp.async to load Q, we want sQ to align to 1024 bytes + sQ_struct, sK_struct, sV_struct = [ + cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes] + for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) + + ] + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] + cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0 + sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024] + # 1 for Q, 1 for O, self.num_stages*2 for K, self.num_stages*2 for V, + mbar_ptr_QO_struct = cute.struct.MemRange[cutlass.Int64, 2] + mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + + @cute.struct + class SharedStorageQKV: + mbar_ptr: mbar_ptr_QO_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_V: mbar_ptr_V_struct + sV: sV_struct + sQ: sQ_struct + sK: sK_struct + sP: sP_struct + + @cute.struct + class SharedStorageSharedQV: + mbar_ptr: mbar_ptr_QO_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_V: mbar_ptr_V_struct + sQ: sQV_struct + sK: sK_struct + sP: sP_struct + + return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + learnable_sink: Optional[cute.Tensor] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + aux_tensors: Optional[list] = None, + ): + """Configures and launches the flash attention kernel. 
+ + mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: + (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) + """ + + self._check_type( + *( + t.element_type if t is not None else None + for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) + ) + ) + + # Assume all strides are divisible by 128 bits except the last stride + # Skip cute.assume() for stride=0 (broadcast dims from expand() are Python ints) + new_stride = lambda t: ( + *( + cute.assume(s, divby=128 // t.element_type.width) + if not isinstance(s, int) or s != 0 + else s + for s in t.stride[:-1] + ), + t.stride[-1], + ) + + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mQ, mK, mV, mO) + ] + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ, mO = [utils.select(t, QO_layout_transpose) for t in (mQ, mO)] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mK, mV = [utils.select(t, KV_layout_transpose) for t in (mK, mV)] + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = utils.select(mLSE, LSE_layout_transpose) if const_expr(mLSE is not None) else None + + tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs = self._get_tiled_mma() + self.num_mma_threads = tiled_mma_qk.size + self.num_threads_per_warp_group = 128 + self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group + self.num_threads = self.num_threads_per_warp_group * (self.num_mma_warp_groups + 1) + self.num_producer_threads = 32 + self.num_Q_load_threads = self.num_mma_threads # If not TMA_Q, MMA threads load Q + self.num_epilogue_threads = self.num_mma_threads + self.num_mma_regs = ( + 256 + if self.num_mma_warp_groups == 1 + else (240 if self.num_mma_warp_groups == 2 else 160) + ) + self.num_producer_regs = ( + 56 if self.num_mma_warp_groups == 1 else (24 if self.num_mma_warp_groups == 2 else 32) + ) + # self.num_mma_regs = 232 + # self.num_producer_regs = 40 + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + + self.use_scheduler_barrier = ( + (self.num_mma_warp_groups >= 2 and self.tile_hdim <= 128) + if const_expr(self.intra_wg_overlap) + else (self.num_mma_warp_groups == 2) + ) + self.use_tma_Q = self.arch >= 90 and not ( + self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0 + ) + self.use_tma_O = ( + self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa + ) + # TODO: rescale_O_before_gemm + self._setup_attributes() + # TODO: we prob don't need most of what's in _setup_attributes + self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [ + sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage) + for mX, shape, stage in [ + (mQ, (self.tile_m, self.tile_hdim), None), + (mK, (self.tile_n, self.tile_hdim), self.num_stages), + (mV, (self.tile_n, self.tile_hdimv), self.num_stages), + (mO, (self.tile_m, self.tile_hdimv), None), + ] + ] + self.sP_layout = None + if const_expr(not self.mma_pv_is_rs): + self.sP_layout = sm90_utils.make_smem_layout( + mV.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n) + ) + + SharedStorage = self._get_shared_storage_cls() + + if const_expr(self.pack_gqa): + shape_Q_packed = ( + (self.qhead_per_kvhead, mQ.shape[0]), + mQ.shape[1], + mK.shape[2], + *mQ.shape[3:], + ) + stride_Q_packed = ( + (mQ.stride[2], mQ.stride[0]), + mQ.stride[1], + mQ.stride[2] * self.qhead_per_kvhead, + *mQ.stride[3:], 
+ ) + mQ = cute.make_tensor( + mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) + ) + shape_O_packed = ( + (self.qhead_per_kvhead, mO.shape[0]), + mK.shape[1], + mK.shape[2], + *mO.shape[3:], + ) + stride_O_packed = ( + (mO.stride[2], mO.stride[0]), + mO.stride[1], + mO.stride[2] * self.qhead_per_kvhead, + *mO.stride[3:], + ) + mO = cute.make_tensor( + mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) + ) + if const_expr(mLSE is not None): + shape_LSE_packed = ( + (self.qhead_per_kvhead, mLSE.shape[0]), + mK.shape[2], + *mLSE.shape[2:], + ) + stride_LSE_packed = ( + (mLSE.stride[1], mLSE.stride[0]), + mLSE.stride[1] * self.qhead_per_kvhead, + *mLSE.stride[2:], + ) + mLSE = cute.make_tensor( + mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) + ) + + # TMA + gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() + gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast + gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp() + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1])) + for name, mX, layout in [ + ("Q", mQ, self.sQ_layout), + ("K", mK, self.sK_layout), + ("V", mV, self.sV_layout), + ] + } + tma_atom_Q, tma_tensor_Q = None, None + if const_expr(self.use_tma_Q): + tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_Q, + mQ, + self.sQ_layout, + (self.tile_m, self.tile_hdim), # No mcast + ) + tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_KV, + mK, + cute.select(self.sK_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdim), + 1, # No mcast for now + ) + tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_KV, + mV, + cute.select(self.sV_layout, mode=[0, 1]), + (self.tile_n, self.tile_hdimv), + 1, # No mcast for now + ) + tma_atom_O, tma_tensor_O = None, None + if const_expr(self.use_tma_O): + tma_atom_O, tma_tensor_O = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_O, + mO, + self.sO_layout, + (self.tile_m, self.tile_hdimv), # No mcast + ) + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler + else: + TileScheduler = ( + SingleTileScheduler + if const_expr(not self.is_causal or self.is_local) + else SingleTileLPTScheduler + ) + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), + 1, # num_splits + cute.size(mK.shape[0]), + mQ.shape[1], + mV.shape[1], + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=(self.tile_m, self.tile_n), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.dtype.width // 8, + is_persistent=False, + lpt=self.is_causal or self.is_local, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + LOG2_E = math.log2(math.e) + if const_expr(self.score_mod is None): + softmax_scale_log2 = softmax_scale * LOG2_E + softmax_scale = None + else: + # NB: If a user passes in a score mod, we want to apply the score-mod in the sm_scaled qk + # But in the original base 10. 
We hijack softmax_scale_log2 to just be the change of base + # and correctly apply the softmax_scale prior to score_mod in the softmax step + softmax_scale_log2 = LOG2_E + softmax_scale = softmax_scale + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + + fastdiv_mods = None + if const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) + seqlen_k = ( + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1] + ) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + + self.kernel( + tma_tensor_Q if const_expr(self.use_tma_Q) else mQ, + tma_tensor_K, + tma_tensor_V, + tma_tensor_O if const_expr(self.use_tma_O) else mO, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_O, + softmax_scale_log2, + softmax_scale, + window_size_left, + window_size_right, + learnable_sink, + blocksparse_tensors, + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sO_layout, + self.sP_layout, + self.gmem_tiled_copy_Q, + self.gmem_tiled_copy_K, + self.gmem_tiled_copy_V, + self.gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, + tiled_mma_pv_rs, + tile_sched_params, + TileScheduler, + SharedStorage, + aux_tensors, + fastdiv_mods, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + tma_atom_Q: Optional[cute.CopyAtom], + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + tma_atom_O: Optional[cute.CopyAtom], + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + learnable_sink: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + sP_layout: cute.ComposedLayout | None, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_K: cute.TiledCopy, + gmem_tiled_copy_V: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tiled_mma_pv_rs: cute.TiledMma, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + SharedStorage: cutlass.Constexpr[Callable], + aux_tensors=Optional[list[cute.Tensor]], + fastdiv_mods=None, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # Prefetch tma descriptor + if warp_idx == 0: + for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O): + if const_expr(tma_atom is not None): + cpasync.prefetch_descriptor(tma_atom) + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # Mbarrier init + mbar_ptr_Q = storage.mbar_ptr.data_ptr() + if warp_idx == 1: + # if tidx < 2: + # # barrierO num threads should be self.num_mma_threads + # cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads) + 
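+            # mbar_ptr_Q only needs initializing when Q is loaded with cp.async rather than
+            # TMA (use_tma_Q == False): it is armed with an expected arrival count of
+            # num_Q_load_threads, i.e. the MMA threads, which load Q themselves in that case.
+            # With TMA Q, the Q bytes are instead counted on the first K-stage barrier via
+            # extra_tx_count in the load warp, so no separate Q barrier is required here.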
if const_expr(not self.use_tma_Q): + cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads) + # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads) + # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync + pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread + ) + pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE + ) + pipeline_k = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.mbar_ptr_K.data_ptr(), + num_stages=self.num_stages, + producer_group=pipeline_kv_producer_group, + consumer_group=pipeline_kv_consumer_group, + tx_count=self.tma_copy_bytes["K"], + defer_sync=True, + ) + pipeline_v = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.mbar_ptr_V.data_ptr(), + num_stages=self.num_stages, + producer_group=pipeline_kv_producer_group, + consumer_group=pipeline_kv_consumer_group, + tx_count=self.tma_copy_bytes["V"], + defer_sync=False + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + if const_expr(not self.Q_in_regs): + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + else: + sV = storage.sQ.get_tensor( + sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type + ) + # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma + sVt = utils.transpose_view(sV) + sP = None + if const_expr(sP_layout is not None): + sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) + # reuse sQ's data iterator + sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype) + + block_info = BlockInfo( + self.tile_m, + self.tile_n, + self.is_causal, + self.is_local, + False, # is_split_kv + window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + ) + AttentionMaskCls = partial( + AttentionMask, + self.tile_m, + self.tile_n, + window_size_left=window_size_left, + window_size_right=window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) + + if warp_idx < 4: # Producer + cute.arch.warpgroup_reg_dealloc(self.num_producer_regs) + self.load( + mQ, + mK, + mV, + sQ, + sK, + sV, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_k, + pipeline_v, + mbar_ptr_Q, + blocksparse_tensors, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + else: # Consumer + cute.arch.warpgroup_reg_alloc(self.num_mma_regs) + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx - 128 + self.mma( + tiled_mma_qk, + tiled_mma_pv, + tiled_mma_pv_rs, + mQ, + mO, + mLSE, + sQ, + sK, + 
sVt, + sP, + sO, + learnable_sink, + pipeline_k, + pipeline_v, + mbar_ptr_Q, + gmem_tiled_copy_Q, + gmem_tiled_copy_O, + tma_atom_O, + tidx, + softmax_scale_log2, + softmax_scale, + block_info, + SeqlenInfoCls, + AttentionMaskCls, + TileSchedulerCls, + blocksparse_tensors, + aux_tensors, + fastdiv_mods, + ) + + @cute.jit + def load( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + mbar_ptr_Q: cutlass.Pointer, + blocksparse_tensors: Optional[BlockSparseTensors], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + if warp_idx_in_wg == 0: + q_producer_phase = Int32(1) + kv_producer_state = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.num_stages + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: + m_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv] + gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0)) + gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) + if const_expr(self.use_tma_Q): + gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) + load_Q, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True + ) + # TODO: mcast + # TODO check warp_idx if we have 128 producer threads + load_K, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_K, 0, cute.make_layout(1), gK, sK + ) + load_K = copy_utils.tma_producer_copy_fn(load_K, pipeline_k) + load_V, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_V, 0, cute.make_layout(1), gV, sV + ) + load_V = copy_utils.tma_producer_copy_fn(load_V, pipeline_v) + + if const_expr(not self.use_block_sparsity): + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) + # First iteration: load both Q & K with the same mbarrier + n_block = n_block_max - 1 + pipeline_k.producer_acquire( + kv_producer_state, + extra_tx_count=self.tma_copy_bytes["Q"] + if const_expr(self.use_tma_Q) + else 0, + ) + if const_expr(self.use_tma_Q): + load_Q(tma_bar_ptr=pipeline_k.producer_get_barrier(kv_producer_state)) + load_K(src_idx=n_block, producer_state=kv_producer_state) + + if const_expr(not self.intra_wg_overlap): + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 1 - i - 1 + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, 
producer_state=kv_producer_state) + kv_producer_state.advance() + else: + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block_prev = n_block_max - i - 1 + n_block = n_block_prev - 1 + kv_producer_state_prev = kv_producer_state.clone() + kv_producer_state.advance() + pipeline_k.producer_acquire(kv_producer_state) + load_K(src_idx=n_block, producer_state=kv_producer_state) + pipeline_v.producer_acquire(kv_producer_state_prev) + load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev) + n_block = n_block_min + pipeline_v.producer_acquire(kv_producer_state) + load_V(src_idx=n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + else: + kv_producer_state = produce_block_sparse_loads( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_k, + pipeline_v, + self.use_tma_Q, + self.tma_copy_bytes["Q"], + self.intra_wg_overlap, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def mma( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tiled_mma_pv_rs: cute.TiledMma, + # softmax: Softmax, + # acc_O: cute.Tensor, + mQ: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + sQ: cute.Tensor, + sK: cute.Tensor, + sVt: cute.Tensor, + sP: Optional[cute.Tensor], + sO: cute.Tensor, + learnable_sink: Optional[cute.Tensor], + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + mbar_ptr_Q: cutlass.Pointer, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + tidx: Int32, + softmax_scale_log2: Float32, + softmax_scale: Optional[Float32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], + aux_tensors: Optional[list], + fastdiv_mods=None, + ): + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + warp_group_thread_layout = cute.make_layout( + self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + ) + thr_mma_qk = tiled_mma_qk.get_slice(tidx) + wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) + tSrQ = tiled_mma_qk.make_fragment_A(wg_mma_qk.partition_A(sQ)) + tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK)) + if const_expr(self.mma_pv_is_rs): + acc_S_shape = tiled_mma_qk.partition_shape_C((self.tile_m, self.tile_n)) + tOrP = cute.make_fragment( + utils.convert_layout_acc_frgA(cute.make_layout(acc_S_shape)), self.dtype + ) + else: + tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP)) + tOrVt = tiled_mma_pv.make_fragment_B(wg_mma_pv.partition_B(sVt)) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype) + smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) + # tPsP = smem_thr_copy_P.partition_D(sP_pi) if const_expr(sP_pi is not None) else None + tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None + # if 
cute.arch.thread_idx()[0] == 0: + # cute.printf(sP_pi.layout, sP_pi.iterator) + # cute.printf(sP.layout, sP.iterator) + # cute.printf(tPsP.layout, tPsP.iterator) + + self.mma_init() + + acc_shape_O = tiled_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv)) + acc_O = cute.make_fragment(acc_shape_O, Float32) + smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) + + mma_qk_fn = partial( + sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK + ) + mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt) + + mma_one_n_block_all = partial( + self.mma_one_n_block_intrawg_overlap + if const_expr(self.intra_wg_overlap) + else self.mma_one_n_block, + mma_qk_fn=mma_qk_fn, + tiled_mma_pv_rs=tiled_mma_pv_rs, + pipeline_k=pipeline_k, + pipeline_v=pipeline_v, + acc_O=acc_O, + tOrP=tOrP, + smem_copy_params=smem_copy_params, + check_inf=True, + ) + + q_consumer_phase = Int32(0) + kv_consumer_state = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.num_stages + ) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + softmax = Softmax.create( + softmax_scale_log2, + num_rows=acc_O.shape[0][0] * acc_O.shape[1], + softmax_scale=softmax_scale, + ) + + process_first_half_block = partial( + self.first_half_block_overlap, + mma_qk_fn=mma_qk_fn, + pipeline_k=pipeline_k, + tOrP=tOrP, + smem_copy_params=smem_copy_params, + softmax=softmax, + ) + process_last_half_block = partial( + self.last_half_block_overlap, + pipeline_v=pipeline_v, + mma_pv_fn=mma_pv_fn, + ) + while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: + + # shape: (atom_v_m * rest_m) + m_block, head_idx, batch_idx, _ = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + + # Recompute fastdiv_mods if necessary for varlen with aux_tensors + recompute_fastdiv_mods_q = cutlass.const_expr( + aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) + ) + recompute_fastdiv_mods_k = cutlass.const_expr( + aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k) + ) + if cutlass.const_expr(fastdiv_mods is not None): + seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods + fastdiv_mods = ( + seqlen_q_divmod + if not recompute_fastdiv_mods_q + else FastDivmodDivisor(seqlen.seqlen_q), + seqlen_k_divmod + if not recompute_fastdiv_mods_k + else FastDivmodDivisor(seqlen.seqlen_k), + ) + + mask = AttentionMaskCls(seqlen) + mask_fn = partial( + mask.apply_mask, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=m_block, + thr_mma=thr_mma_qk, + mask_causal=self.is_causal, + mask_local=self.is_local, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + score_mod_fn = None + if const_expr(self.score_mod is not None): + score_mod_fn = partial( + self.apply_score_mod, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + softmax_scale=softmax_scale, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + ) + mma_one_n_block = partial( + mma_one_n_block_all, + seqlen=seqlen, + softmax=softmax, + score_mod_fn=score_mod_fn, + ) + # Load Q if not TMA_Q + if const_expr(not self.use_tma_Q): + pack_gqa = PackGQA( + self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead + ) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) + # gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) + # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, 
seqlen=seqlen.seqlen_q, + # headdim=mQ.shape[1]) + pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q) + cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q) + + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + if const_expr(not self.use_tma_Q): + cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase) + q_consumer_phase ^= 1 + # For performance reason, we separate out two kinds of iterations: + # those that need masking on S, and those that don't. + # We need masking on S for the very last block when K and V has length not multiple of tile_n. + # We also need masking on S if it's causal, for the last several blocks. + # softmax.reset() # Don't need reset as we explicitly call softmax w is_first=True + O_should_accumulate = False + + # ========================================== + # MAINLOOP + # ========================================== + if const_expr(not self.use_block_sparsity): + # ========================================== + # No block-sparsity (original path) + # ========================================== + # First iteration with seqlen masking + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_first_half_block( + n_block=n_block_max - 1, + seqlen=seqlen, + kv_consumer_state=kv_consumer_state, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod), + score_mod_fn=score_mod_fn, + is_first_block=True, + ) + # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + # acc_O.fill(0.0) + else: + self.warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=True), + is_first_n_block=True, + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), + ) + O_should_accumulate = True + # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) + for n_tile in cutlass.range( + n_block_max - n_block_min_causal_local_mask, unroll=1 + ): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + # Separate iterations with local masking on the left + if 
const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): + kv_consumer_state = mma_one_n_block( + kv_consumer_state, + n_block=n_block_max - 1 - n_tile, + seqlen=seqlen, + mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate), + mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False), + ) + O_should_accumulate = True + # Last "half" iteration + if const_expr(self.intra_wg_overlap): + kv_consumer_state = process_last_half_block( + kv_consumer_state=kv_consumer_state, + zero_init=not O_should_accumulate, + ) + O_should_accumulate = True + else: + self.warp_scheduler_barrier_arrive() + + else: + # ========================================== + # Block sparsity + # ========================================== + kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + seqlen, + kv_consumer_state, + mma_pv_fn, + mma_one_n_block, + process_first_half_block, + process_last_half_block, + mask_fn, + score_mod_fn, + O_should_accumulate, + self.mask_mod, + fastdiv_mods, + self.intra_wg_overlap, + self.warp_scheduler_barrier_sync, + self.warp_scheduler_barrier_arrive, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + + # Handle empty case (when no blocks to process) + if not processed_any: + softmax.reset() + acc_O.fill(0.0) + + sink_val = None + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + else: # Each thread might have a different sink value due to different q_head + sink_val = cute.make_fragment_like(softmax.row_max, Float32) + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + tScS_mn = utils.make_acc_tensor_mn_view(thr_mma_qk.partition_C(cS)) + for r in cutlass.range(cute.size(sink_val), unroll_full=True): + row = m_block * self.tile_m + tScS_mn[r][0] + q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + sink_val[r] = Float32(learnable_sink[q_head_idx]) + + # normalize acc_O by row_sum and calculate the lse + row_scale = softmax.finalize(sink_val=sink_val) + softmax.rescale_O(acc_O, row_scale) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + self.epilogue( + acc_O, + softmax.row_sum, + mO, + mLSE, + sO, + seqlen, + gmem_tiled_copy_O, + tma_atom_O, + tiled_mma_pv, + tidx, + m_block, + head_idx, + batch_idx, + ) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + + @cute.jit + def first_half_block_overlap( + self, + n_block: Int32, + mma_qk_fn: Callable, + kv_consumer_state, + pipeline_k, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + seqlen: SeqlenInfoQK, + mask_fn: Callable = None, + score_mod_fn: Optional[Callable] = None, + is_first_block: bool = False, + ): + """Processes the first half block when using intra-warpgroup-overlap""" + + pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) + acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) + pipeline_k.consumer_release(kv_consumer_state) + + # Apply score modification if present + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) + + # Apply mask; mask_seqlen 
always True for first block + # Caveat: if full block further right than mask block, seqlen masking is redundant; + # however, masking is being applied anyway, so essentially no perf hit + mask_fn(acc_S, n_block=n_block, mask_seqlen=True) + + softmax.online_softmax(acc_S, is_first=is_first_block) + + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) + tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + + # if pv gemm not rs + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + # Fence and barrier to make smem store visible to WGMMA + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.sync_warp() + + return kv_consumer_state + + @cute.jit + def last_half_block_overlap( + self, + kv_consumer_state, + pipeline_v, + mma_pv_fn: Callable, + zero_init: bool, + ): + """Processes the final PV GEMM when using intra-warpgroup-overlap""" + + pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) + mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0) + pipeline_v.consumer_release(kv_consumer_state) + kv_consumer_state.advance() + return kv_consumer_state + + @cute.jit + def mma_one_n_block( + self, + smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + n_block: Int32, + mma_qk_fn: Callable, + mma_pv_fn: Callable, + tiled_mma_pv_rs: cute.TiledMma, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + acc_O: cute.Tensor, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + seqlen: SeqlenInfoQK, + score_mod_fn: Optional[Callable] = None, + mask_fn: Optional[Callable] = None, + is_first_n_block: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, + ): + pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) + # S = Q @ K.T + acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) + self.warp_scheduler_barrier_arrive() + warpgroup.wait_group(0) + pipeline_k.consumer_release(smem_pipe_read) + + # handle score mods and masking + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) + if const_expr(mask_fn is not None): + mask_fn(acc_S=acc_S, n_block=n_block) + + row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) + # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) + # tOrP.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. 
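+        # Illustrative sketch (not part of the kernel): utils.cvt_f16 is assumed to
+        # emit one packed two-element conversion (e.g. PTX cvt.rn.f16x2.f32 /
+        # cvt.rn.bf16x2.f32) per pair of fp32 accumulator values, conceptually:
+        #     for i in range(0, cute.size(tOrP_acc), 2):
+        #         tOrP_cur[i], tOrP_cur[i + 1] = cvt_pair(tOrP_acc[i], tOrP_acc[i + 1])
+        # where cvt_pair is a hypothetical name for the packed convert instruction.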
+ utils.cvt_f16(tOrP_acc, tOrP_cur) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + softmax.rescale_O(acc_O, row_scale) + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) + self.warp_scheduler_barrier_sync() + # O += P @ V + mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0) + pipeline_v.consumer_release(smem_pipe_read) + smem_pipe_read.advance() + return smem_pipe_read + + @cute.jit + def mma_one_n_block_intrawg_overlap( + self, + smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + n_block: Int32, + mma_qk_fn: Callable, + mma_pv_fn: Callable, + tiled_mma_pv_rs: cute.TiledMma, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + acc_O: cute.Tensor, + tOrP: cute.Tensor, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + seqlen: SeqlenInfoQK, + score_mod_fn: Optional[Callable] = None, + mask_fn: Optional[Callable] = None, + check_inf: cutlass.Constexpr = True, + ): + smem_pipe_read_v = smem_pipe_read.clone() + smem_pipe_read.advance() + pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) + self.warp_scheduler_barrier_sync() + # S = Q @ K.T + acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) + pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) + # O += P @ V + mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1) + self.warp_scheduler_barrier_arrive() + warpgroup.wait_group(1) + pipeline_k.consumer_release(smem_pipe_read) + + # handle score mods and masking + if const_expr(score_mod_fn is not None): + score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) + if const_expr(mask_fn is not None): + mask_fn(acc_S=acc_S, n_block=n_block) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + + row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) + warpgroup.wait_group(0) + pipeline_v.consumer_release(smem_pipe_read_v) + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP_cur = ( + tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + ) + # tOrP_cur.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. 
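+        # (Same packed fp32 -> fp16/bf16 conversion rationale as in mma_one_n_block.)
+        # Note: in this overlapped variant the PV GEMM that consumes the P staged
+        # below is issued by the *next* call (or by last_half_block_overlap for the
+        # final block); here we only convert P, store it, and rescale O before
+        # returning.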
+ utils.cvt_f16(tOrP_acc, tOrP_cur) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + softmax.rescale_O(acc_O, row_scale) + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta) + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + return smem_pipe_read + + @cute.jit + def mma_init(self): + warp_group_idx = utils.canonical_warp_group_idx(sync=False) + if const_expr(self.use_scheduler_barrier): + if warp_group_idx == 1: + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + @cute.jit + def apply_score_mod( + self, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + acc_S, + n_block, + softmax_scale, + seqlen, + aux_tensors: Optional[list] = None, + fastdiv_mods=None, + ): + # Prepare index tensor + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS) + tScS = thr_mma_qk.partition_C(cS) + + apply_score_mod_inner( + acc_S, + tScS, + self.score_mod, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info=seqlen, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + + def warp_scheduler_barrier_sync(self): + if const_expr(self.use_scheduler_barrier): + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + - 1 + + utils.canonical_warp_group_idx(sync=False), + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + def warp_scheduler_barrier_arrive(self): + if const_expr(self.use_scheduler_barrier): + assert self.num_mma_warp_groups in [2, 3] + cur_wg = utils.canonical_warp_group_idx(sync=False) - 1 + if const_expr(self.num_mma_warp_groups == 2): + next_wg = 1 - cur_wg + else: + t = cur_wg + 1 + next_wg = t % self.num_mma_warp_groups + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, + number_of_threads=2 * self.num_threads_per_warp_group, + ) diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py new file mode 100644 index 00000000000..f97e127175d --- /dev/null +++ b/flash_attn/cute/flash_fwd_combine.py @@ -0,0 +1,704 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_combine_kernel.h +# from Cutlass C++ to Cute-DSL. +import math +import operator +from typing import Type, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync +from cutlass import Float32, Int32, const_expr + +from flash_attn.cute import utils +from flash_attn.cute.seqlen_info import SeqlenInfo +from cutlass.cute import FastDivmodDivisor + + +class FlashAttentionForwardCombine: + def __init__( + self, + dtype: Type[cutlass.Numeric], + dtype_partial: Type[cutlass.Numeric], + head_dim: int, + m_block_size: int = 8, + k_block_size: int = 64, + log_max_splits: int = 4, + num_threads: int = 256, + stages: int = 4, + ): + """ + Forward combine kernel for split attention computation. 
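+
+        Each split s produces a partial output O_s and a partial log-sum-exp LSE_s
+        for the query rows it covered; this kernel merges them per row into the
+        final O and LSE. A NumPy reference of the per-row combine (illustrative
+        sketch only; the kernel additionally handles -inf LSEs from empty splits),
+        with lse_partial of shape (num_splits, rows) and o_partial of shape
+        (num_splits, rows, headdim):
+
+            lse_max = lse_partial.max(axis=0)
+            lse = np.log(np.exp(lse_partial - lse_max).sum(axis=0)) + lse_max
+            out = (np.exp(lse_partial - lse)[..., None] * o_partial).sum(axis=0)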
+ + :param dtype: output data type + :param dtype_partial: partial accumulation data type + :param head_dim: head dimension + :param m_block_size: m block size + :param k_block_size: k block size + :param log_max_splits: log2 of maximum splits + :param num_threads: number of threads + :param varlen: whether using variable length sequences + :param stages: number of pipeline stages + """ + self.dtype = dtype + self.dtype_partial = dtype_partial + self.head_dim = head_dim + self.m_block_size = m_block_size + self.k_block_size = k_block_size + self.max_splits = 1 << log_max_splits + self.num_threads = num_threads + self.is_even_k = head_dim % k_block_size == 0 + self.stages = stages + + @staticmethod + def can_implement( + dtype, + dtype_partial, + head_dim, + m_block_size, + k_block_size, + log_max_splits, + num_threads, + ) -> bool: + """Check if the kernel can be implemented with the given parameters.""" + if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]: + return False + if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]: + return False + if head_dim % 8 != 0: + return False + if num_threads % 32 != 0: + return False + if m_block_size % 8 != 0: + return False + max_splits = 1 << log_max_splits + if max_splits > 256: + return False + if (m_block_size * max_splits) % num_threads != 0: + return False + return True + + def _setup_attributes(self): + # GMEM copy setup for O partial + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype_partial.width + assert self.k_block_size % async_copy_elems == 0 + + k_block_gmem = ( + 128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32) + ) + gmem_threads_per_row = k_block_gmem // async_copy_elems + assert self.num_threads % gmem_threads_per_row == 0 + + # Async copy atom for O partial load + atom_async_copy_partial = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype_partial, + num_bits_per_copy=universal_copy_bits, + ) + tOpartial_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + vOpartial_layout = cute.make_layout((1, async_copy_elems)) # 4 vals per load + self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( + atom_async_copy_partial, tOpartial_layout, vOpartial_layout + ) + + # GMEM copy setup for final O (use universal copy for store) + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=async_copy_elems * self.dtype.width, + ) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, + tOpartial_layout, + vOpartial_layout, # 4 vals per store + ) + + # LSE copy setup with async copy (alignment = 1) + lse_copy_bits = Float32.width # 1 element per copy, width is in bits + m_block_smem = ( + 128 + if self.m_block_size % 128 == 0 + else ( + 64 + if self.m_block_size % 64 == 0 + else ( + 32 + if self.m_block_size % 32 == 0 + else (16 if self.m_block_size % 16 == 0 else 8) + ) + ) + ) + gmem_threads_per_row_lse = m_block_smem + assert self.num_threads % gmem_threads_per_row_lse == 0 + + # Async copy atom for LSE load + atom_async_copy_lse = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + Float32, + num_bits_per_copy=lse_copy_bits, + ) + tLSE_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), + order=(1, 0), + ) + vLSE_layout = cute.make_layout(1) + 
self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_lse, tLSE_layout, vLSE_layout + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory + # /////////////////////////////////////////////////////////////////////////////// + + # Shared memory to register copy for LSE + self.smem_threads_per_col_lse = self.num_threads // m_block_smem + assert 32 % self.smem_threads_per_col_lse == 0 # Must divide warp size + + s2r_layout_atom_lse = cute.make_ordered_layout( + (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), + order=(0, 1), + ) + self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), + s2r_layout_atom_lse, + cute.make_layout(1), + ) + + # LSE shared memory layout with swizzling to avoid bank conflicts + # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts + if const_expr(m_block_smem == 8): + smem_lse_swizzle = cute.make_swizzle(5, 0, 5) + elif const_expr(m_block_smem == 16): + smem_lse_swizzle = cute.make_swizzle(4, 0, 4) + else: + smem_lse_swizzle = cute.make_swizzle(3, 2, 3) + smem_layout_atom_lse = cute.make_composed_layout( + smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0)) + ) + self.smem_layout_lse = cute.tile_to_shape( + smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1) + ) + + # O partial shared memory layout (simple layout for pipeline stages) + self.smem_layout_o = cute.make_ordered_layout( + (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2) + ) + + @cute.jit + def __call__( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor] = None, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + semaphore_to_reset: Optional[cute.Tensor] = None, + stream: cuda.CUstream = None, + ): + # Type checking + if const_expr(not (mO_partial.element_type == self.dtype_partial)): + raise TypeError("O partial tensor must match dtype_partial") + if const_expr(not (mO.element_type == self.dtype)): + raise TypeError("O tensor must match dtype") + if const_expr(mLSE_partial.element_type not in [Float32]): + raise TypeError("LSE partial tensor must be Float32") + if const_expr(mLSE is not None and mLSE.element_type not in [Float32]): + raise TypeError("LSE tensor must be Float32") + + # Shape validation - input tensors are in user format, need to be converted to kernel format + if const_expr(len(mO_partial.shape) not in [4, 5]): + raise ValueError( + "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)" + ) + if const_expr(len(mLSE_partial.shape) not in [3, 4]): + raise ValueError( + "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" + ) + if const_expr(len(mO.shape) not in [3, 4]): + raise ValueError( + "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)" + ) + if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]): + raise ValueError( + "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)" + ) + + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + 
t.stride[-1], + ) + mO_partial, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mO_partial, mO) + ] + # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b) + # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h) + O_partial_layout_transpose = ( + [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] + ) + # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h) + mO_partial = cute.make_tensor( + mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose) + ) + O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1] + mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) + # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b) + # or (num_splits, total_q, h) -> (total_q, num_splits, h) + LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2] + mLSE_partial = cute.make_tensor( + mLSE_partial.iterator, + cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose), + ) + # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h) + LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1] + mLSE = ( + cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + if mLSE is not None + else None + ) + + # Determine if we have variable length sequences + varlen = const_expr(cu_seqlens is not None or seqused is not None) + + self._setup_attributes() + + @cute.struct + class SharedStorage: + sLSE: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128] + sO: cute.struct.Align[ + cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 + ] + + smem_size = SharedStorage.size_in_bytes() + + # Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch) + seqlen = mO_partial.shape[0] + num_head = mO_partial.shape[3] + batch_size = ( + mO_partial.shape[4] + if const_expr(cu_seqlens is None) + else Int32(cu_seqlens.shape[0] - 1) + ) + + # Create FastDivmodDivisor objects for efficient division + seqlen_divmod = FastDivmodDivisor(seqlen) + head_divmod = FastDivmodDivisor(num_head) + + grid_dim = ( + cute.ceil_div(seqlen * num_head, self.m_block_size), + cute.ceil_div(self.head_dim, self.k_block_size), + batch_size, + ) + + self.kernel( + mO_partial, + mLSE_partial, + mO, + mLSE, + cu_seqlens, + seqused, + num_splits_dynamic_ptr, + semaphore_to_reset, + SharedStorage, + self.smem_layout_lse, + self.smem_layout_o, + self.gmem_tiled_copy_O_partial, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_LSE, + self.s2r_tiled_copy_LSE, + seqlen_divmod, + head_divmod, + varlen, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=smem_size, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + cu_seqlens: Optional[cute.Tensor], + seqused: Optional[cute.Tensor], + num_splits_dynamic_ptr: Optional[cute.Tensor], + semaphore_to_reset: Optional[cute.Tensor], + SharedStorage: cutlass.Constexpr, + smem_layout_lse: cute.Layout | cute.ComposedLayout, + smem_layout_o: cute.Layout, + gmem_tiled_copy_O_partial: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + s2r_tiled_copy_LSE: 
cute.TiledCopy, + seqlen_divmod: FastDivmodDivisor, + head_divmod: FastDivmodDivisor, + varlen: cutlass.Constexpr[bool], + ): + # Thread and block indices + tidx, _, _ = cute.arch.thread_idx() + m_block, k_block, batch_idx = cute.arch.block_idx() + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sLSE = storage.sLSE.get_tensor(smem_layout_lse) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.m_block_size,)) + sO = storage.sO.get_tensor(smem_layout_o) + + # Handle semaphore reset + if const_expr(semaphore_to_reset is not None): + if ( + tidx == 0 + and m_block == cute.arch.grid_dim()[0] - 1 + and k_block == cute.arch.grid_dim()[1] - 1 + and batch_idx == cute.arch.grid_dim()[2] - 1 + ): + semaphore_to_reset[0] = 0 + + # Get number of splits + num_splits = ( + num_splits_dynamic_ptr[batch_idx] + if const_expr(num_splits_dynamic_ptr is not None) + else mLSE_partial.shape[1] + ) + # Handle variable length sequences using SeqlenInfo + seqlen_info = SeqlenInfo.create( + batch_idx=batch_idx, + seqlen_static=mO_partial.shape[0], + cu_seqlens=cu_seqlens, + seqused=seqused, + ) + seqlen, offset = seqlen_info.seqlen, seqlen_info.offset + + # Extract number of heads (head index will be determined dynamically) + num_head = mO_partial.shape[3] + max_idx = seqlen * num_head + + # Early exit for single split if dynamic + if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and ( + const_expr(not varlen) or m_block * self.m_block_size < max_idx + ): + # =============================== + # Step 1: Load LSE_partial from gmem to shared memory + # =============================== + + if const_expr(cu_seqlens is None): + # mLSE_partial_cur = mLSE_partial[None, None, None, batch_idx] + mLSE_partial_cur = utils.coord_offset_i64(mLSE_partial, batch_idx, dim=3) + else: + # mLSE_partial_cur = cute.domain_offset((offset, 0, 0), mLSE_partial) + mLSE_partial_cur = utils.domain_offset_i64((offset, 0, 0), mLSE_partial) + mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,)) + + gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) + tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) + + # Create identity tensor for coordinate tracking + cLSE = cute.make_identity_tensor((self.max_splits, self.m_block_size)) + tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) + + # Load LSE partial values + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] # Get m coordinate + idx = m_block * self.m_block_size + mi + if idx < max_idx: + # Calculate actual sequence position and head using FastDivmodDivisor + if const_expr(not varlen): + head_idx, m_idx = divmod(idx, seqlen_divmod) + else: + head_idx = idx // seqlen + m_idx = idx - head_idx * seqlen + mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx] + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] # Get split coordinate + if si < num_splits: + cute.copy( + gmem_thr_copy_LSE, + mLSE_partial_cur_copy[None, si], + tLSEsLSE[None, s, m], + ) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem + cute.arch.cp_async_commit_group() + + # =============================== + # Step 2: Load O_partial for pipeline stages + # 
=============================== + + gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) + cO = cute.make_identity_tensor((self.m_block_size, self.k_block_size)) + tOcO = gmem_thr_copy_O_partial.partition_D(cO) + tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) + if const_expr(cu_seqlens is None): + # mO_partial_cur = mO_partial[None, None, None, None, batch_idx] + mO_partial_cur = utils.coord_offset_i64(mO_partial, batch_idx, dim=4) + else: + # mO_partial_cur = cute.domain_offset((offset, 0, 0, 0), mO_partial) + mO_partial_cur = utils.domain_offset_i64((offset, 0, 0, 0), mO_partial) + + # Precompute these values to avoid recomputing them in the loop + num_rows = const_expr(cute.size(tOcO, mode=[1])) + tOmidx = cute.make_fragment(num_rows, cutlass.Int32) + tOhidx = cute.make_fragment(num_rows, cutlass.Int32) + tOrOptr = cute.make_fragment(num_rows, cutlass.Int64) + for m in cutlass.range(num_rows, unroll_full=True): + mi = tOcO[0, m, 0][0] # m coordinate + idx = m_block * self.m_block_size + mi + if const_expr(not varlen): + tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod) + else: + tOhidx[m] = idx // seqlen + tOmidx[m] = idx - tOhidx[m] * seqlen + tOrOptr[m] = utils.elem_pointer_i64( + mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m]) + ).toint() + if idx >= max_idx: + tOhidx[m] = -1 + + tOpO = cute.make_fragment(cute.size(tOcO, [2]), cutlass.Boolean) + if const_expr(not self.is_even_k): + for k in cutlass.range(cute.size(tOpO), unroll_full=True): + tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size + # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) + + load_O_partial = partial( + self.load_O_partial, + gmem_tiled_copy_O_partial, + tOrOptr, + tOsO_partial, + tOhidx, + tOpO, + tOcO, + mO_partial_cur.layout, + ) + + # Load first few stages of O_partial + for stage in cutlass.range(self.stages - 1, unroll_full=True): + if stage < num_splits: + load_O_partial(stage, stage) + cute.arch.cp_async_commit_group() + + # =============================== + # Step 3: Load and transpose LSE from smem to registers + # =============================== + + # Wait for LSE and initial O partial stages to complete + cute.arch.cp_async_wait_group(self.stages - 1) + cute.arch.sync_threads() + # if cute.arch.thread_idx()[0] == 0: + # # cute.print_tensor(sLSE) + # for i in range(64): + # cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0]) + # cute.arch.sync_threads() + + s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) + ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) + ts2rrLSE = cute.make_fragment_like(ts2rsLSE) + cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) + + # =============================== + # Step 4: Compute final LSE along split dimension + # =============================== + + lse_sum = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Float32) + ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) + # We compute the max valid split for each row to short-circuit the computation later + max_valid_split = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Int32) + assert cute.size(ts2rrLSE, mode=[0]) == 1 + # Compute max, scales, and final LSE for each row + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + # Find max LSE value across splits + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_max = utils.warp_reduce( + ts2rrLSE[None, None, m] + .load() + .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + op=cute.arch.fmax, + 
width=threads_per_col, + ) + # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max) + # Find max valid split index + max_valid_idx = -1 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + if ts2rrLSE[0, s, m] != -Float32.inf: + max_valid_idx = ts2rcLSE[0, s, 0][0] # Get split coordinate + # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx) + max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col) + # Compute exp scales and sum + lse_max_cur = ( + 0.0 if lse_max == -Float32.inf else lse_max + ) # In case all local LSEs are -inf + LOG2_E = math.log2(math.e) + lse_sum_cur = 0.0 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + scale = utils.exp2f(ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E)) + lse_sum_cur += scale + ts2rrLSE[0, s, m] = scale # Store scale for later use + lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col) + lse_sum[m] = utils.logf(lse_sum_cur) + lse_max + # Normalize scales + inv_sum = ( + 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur + ) + ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) + # Store the scales exp(lse - lse_logsum) back to smem + cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) + + # Store max valid split to smem + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + if mi < self.m_block_size: + sMaxValidSplit[mi] = max_valid_split[m] + + # =============================== + # Step 5: Store final LSE to gmem + # =============================== + + if const_expr(mLSE is not None): + if const_expr(cu_seqlens is None): + # mLSE_cur = mLSE[None, None, batch_idx] + mLSE_cur = utils.coord_offset_i64(mLSE, batch_idx, dim=2) + else: + # mLSE_cur = cute.domain_offset((offset, 0), mLSE) + mLSE_cur = utils.domain_offset_i64((offset, 0), mLSE) + if k_block == 0: # Only first k_block writes LSE when mLSE is provided + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + idx = m_block * self.m_block_size + mi + if idx < max_idx: + if const_expr(not varlen): + head_idx, m_idx = divmod(idx, seqlen_divmod) + else: + head_idx = idx // seqlen + m_idx = idx - head_idx * seqlen + mLSE_cur[m_idx, head_idx] = lse_sum[m] + + # =============================== + # Step 6: Read O_partial and accumulate final O + # =============================== + + cute.arch.sync_threads() + + # Get max valid split for this thread + thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] + for m in cutlass.range(1, cute.size(tOcO, mode=[1])): + thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]]) + + tOrO_partial = cute.make_fragment_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_fragment_like(tOrO_partial, Float32) + tOrO.fill(0.0) + + stage_load = self.stages - 1 + stage_compute = 0 + + # Main accumulation loop + for s in cutlass.range(thr_max_valid_split + 1, unroll=4): + # Get scales for this split + scale = cute.make_fragment(num_rows, Float32) + for m in cutlass.range(num_rows, unroll_full=True): + scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem + + # Load next stage if needed + split_to_load = s + self.stages - 1 + if split_to_load <= thr_max_valid_split: + load_O_partial(split_to_load, stage_load) + 
cute.arch.cp_async_commit_group() + stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 + + # Wait for the current stage to be ready + cute.arch.cp_async_wait_group(self.stages - 1) + # We don't need __syncthreads() because each thread is just reading its own data from smem + # Copy from smem to registers + cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) + stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 + + # Accumulate scaled partial results + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0 and scale[m] > 0.0: + tOrO[None, m, None].store( + tOrO[None, m, None].load() + + scale[m] * tOrO_partial[None, m, None].load().to(Float32) + ) + + # =============================== + # Step 7: Write final O to gmem + # =============================== + + rO = cute.make_fragment_like(tOrO, self.dtype) + rO.store(tOrO.load().to(self.dtype)) + if const_expr(cu_seqlens is None): + # mO_cur = mO[None, None, None, batch_idx] + mO_cur = utils.coord_offset_i64(mO, batch_idx, dim=3) + else: + # mO_cur = cute.domain_offset((offset, 0, 0), mO) + mO_cur = utils.domain_offset_i64((offset, 0, 0), mO) + mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur) + elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) + # mO_cur_copy = cute.tiled_divide(mO_cur, (1, elems_per_store,)) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + # Write final results + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0: + mO_cur_copy = cute.tiled_divide( + mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,) + ) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_store + if const_expr(self.is_even_k) or tOpO[k]: + cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx]) + + @cute.jit + def load_O_partial( + self, + gmem_tiled_copy_O_partial: cute.TiledCopy, + tOrOptr: cute.Tensor, + tOsO_partial: cute.Tensor, + tOhidx: cute.Tensor, + tOpO: cute.Tensor, + tOcO: cute.Tensor, + mO_cur_partial_layout: cute.Layout, + split: Int32, + stage: Int32, + ) -> None: + elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) + tOsO_partial_cur = tOsO_partial[None, None, None, stage] + for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): + if tOhidx[m] >= 0: + o_gmem_ptr = cute.make_ptr( + tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16 + ) + mO_partial_cur = cute.make_tensor( + o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)) + ) + mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_load + if const_expr(self.is_even_k) or tOpO[k]: + cute.copy( + gmem_tiled_copy_O_partial, + # mO_partial_cur_copy[None, k_idx, split], + utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)[None, k_idx], + tOsO_partial_cur[None, m, k], + ) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py new file mode 100644 index 00000000000..cc81edaf84a --- /dev/null +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -0,0 +1,2740 @@ +# Supported features: +# - BF16 & FP16 dtype +# - noncausal & causal attention +# - MHA, GQA, MQA +# - hdim 64, 96, 128, (192, 128). 
+# - varlen +# - sliding window +# - split-kv +# Unsupported features that will be added later: +# - page size != 128 +# - more hdim (192, 256) +# Based on the cutlass example and cute-dsl example: +# https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha +# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py + +import enum +import math +from typing import Type, Tuple, Callable, Optional, Literal +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils_basic + +from flash_attn.cute.paged_kv import PagedKVManager +import flash_attn.cute.utils as utils +from flash_attn.cute import copy_utils +import flash_attn.cute.pipeline as pipeline +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + get_total_block_count, + produce_block_sparse_loads_sm100, + softmax_block_sparse_sm100, + handle_block_sparse_empty_tile_correction_sm100, +) +from flash_attn.cute.pack_gqa import PackGQA +from flash_attn.cute import mma_sm100_desc as sm100_desc +from flash_attn.cute import blackwell_helpers as sm100_utils +from cutlass.cute import FastDivmodDivisor +from flash_attn.cute.tile_scheduler import ( + TileSchedulerArguments, + SingleTileScheduler, + StaticPersistentTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + ParamsBase, +) + + +class NamedBarrierFwd(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() +# WarpSchedulerWG1 = enum.auto() +# WarpSchedulerWG2 = enum.auto() +# WarpSchedulerWG3 = enum.auto() +# PFull = enum.auto() +# PEmpty = enum.auto() + + +class FlashAttentionForwardSm100: + arch = 100 + + def __init__( + self, + # dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, + is_causal: bool = False, + is_local: bool = False, + is_split_kv: bool = False, + pack_gqa: bool = False, + m_block_size: int = 128, + n_block_size: int = 128, + q_stage: cutlass.Constexpr[int] = 2, + is_persistent: bool = True, + score_mod: cutlass.Constexpr | None = None, + mask_mod: cutlass.Constexpr | None = None, + has_aux_tensors: cutlass.Constexpr = False, + paged_kv_non_tma: bool = False, + is_varlen_q: bool = False, + ): + self.use_tma_KV = not paged_kv_non_tma + # self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + self.same_hdim_kv_padded = self.head_dim_padded == self.head_dim_v_padded + self.check_hdim_oob = head_dim != self.head_dim_padded + self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + self.m_block_size = m_block_size + self.n_block_size = n_block_size + self.q_stage = q_stage + assert self.q_stage in [1, 2] + + # 2 Q tile per CTA + 
self.cta_tiler = (self.q_stage * m_block_size, n_block_size, self.head_dim_padded) + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim_padded) + self.mma_tiler_pv = (m_block_size, self.head_dim_v_padded, n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + self.cluster_shape_mn = (1, 1) + self.is_persistent = is_persistent + self.is_causal = is_causal + self.is_local = is_local + self.is_varlen_q = is_varlen_q + self.use_correction_warps_for_epi = is_varlen_q + self.qhead_per_kvhead = qhead_per_kvhead + self.is_split_kv = is_split_kv + self.pack_gqa = pack_gqa + if pack_gqa: + assert m_block_size % self.qhead_per_kvhead == 0, ( + "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" + ) + assert not (self.is_split_kv and self.head_dim_v_padded >= 192), ( + "SplitKV is not supported for hdim >= 192" + ) + self.score_mod = score_mod + self.mask_mod = mask_mod + if cutlass.const_expr(has_aux_tensors): + self.vec_size: cutlass.Constexpr = 1 + else: + self.vec_size: cutlass.Constexpr = 2 + # Does S1 need to wait for S0 to finish + # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) + self.s0_s1_barrier = False + self.overlap_sO_sQ = ( + (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or + (self.head_dim_v_padded >= 128 and self.is_split_kv) + ) + if self.overlap_sO_sQ: + self.is_persistent = False + + assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), ( + "Paged KV does not support irregular head dim" + ) + + self.softmax0_warp_ids = (0, 1, 2, 3) + self.softmax1_warp_ids = (4, 5, 6, 7) + self.correction_warp_ids = (8, 9, 10, 11) + self.mma_warp_id = 12 + self.epilogue_warp_ids = (13,) + self.load_warp_ids = (14,) + self.empty_warp_ids = (15,) + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.threads_per_cta = cute.arch.WARP_SIZE * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + self.mma_warp_id, + *self.load_warp_ids, + *self.epilogue_warp_ids, + *self.empty_warp_ids, + ) + ) + + if self.q_stage == 1: + if not self.use_tma_KV: + self.empty_warp_ids = self.empty_warp_ids + self.load_warp_ids + self.load_warp_ids = self.softmax1_warp_ids + else: + self.empty_warp_ids = self.empty_warp_ids + self.softmax1_warp_ids + self.softmax1_warp_ids = () + elif not self.use_tma_KV: + self.load_warp_ids = (14, 15) + self.empty_warp_ids = () + + if self.use_correction_warps_for_epi: + self.empty_warp_ids = self.empty_warp_ids + self.epilogue_warp_ids + self.epilogue_warp_ids = self.correction_warp_ids + elif self.is_varlen_q: # fallback + self.epilogue_warp_ids = (13, 14) + + self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 + self.tmem_o_offset = [ + self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded + for i in range(self.q_stage) + ] # e.g., 256, 384 + self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded + assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS + self.tmem_s_to_p_offset = self.n_block_size // 2 + self.tmem_p_offset = [ + self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2) + ] # 0, 128 + + # vec buffer for row_max & row_sum + self.tmem_vec_offset = self.tmem_s_offset + + if self.head_dim_padded < 96: + self.num_regs_softmax = 200 if not paged_kv_non_tma else 184 + self.num_regs_correction = 64 + self.num_regs_other = 48 if not paged_kv_non_tma else 80 + else: + # self.num_regs_softmax = 192 if self.is_causal 
or self.is_local else 184 + self.num_regs_softmax = 200 if not paged_kv_non_tma else 184 + # self.num_regs_softmax = 176 + # self.num_regs_correction = 96 + # self.num_regs_correction = 80 + # self.num_regs_correction = 64 if self.is_causal or self.is_local else 80 + self.num_regs_correction = 64 + # self.num_regs_other = 32 + # self.num_regs_other = 64 + # self.num_regs_other = 80 + self.num_regs_other = 48 if not paged_kv_non_tma else 80 + # self.num_regs_other = 96 if self.is_causal or self.is_local else 80 + # self.num_regs_other = 64 if self.is_causal or self.is_local else 80 + self.num_regs_empty = 24 + + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + """Set up configurations and parameters for the FMHA kernel operation. + + This method initializes and configures various attributes required for the + execution of the fused multi-head attention kernel, mainly about the pipeline stages: + + - Sets up staging parameters for Q, K, V inputs and accumulator data + - Configures pipeline stages for softmax, correction, and epilogue operations + """ + + self.kv_stage = 4 if self.q_dtype.width == 8 or self.q_stage == 1 else 3 + self.acc_stage = 1 + # For hdim 192,128, we don't have enough smem to store all 3 stages of KV: + # 128 x 192 x 2 bytes x 3 stages = 144KB, and we need 96KB for Q. + # Instead we store smem as [smem_large, smem_small, smem_large], where smem_large is + # 128 x 192 and smem_small is 128 x 128. We set the stride between the stages to be + # 128 * 160, so that indexing the 0th and 2nd stages will get the right address, + # but for the 1st stage we need to add or subtract (depending on phase) 128 x 64. + self.uneven_kv_smem = ( + self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3 + ) + self.uneven_kv_smem_offset = ( + self.m_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2 + if self.uneven_kv_smem + else 0 + ) + assert self.uneven_kv_smem_offset % 1024 == 0 + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + learnable_sink: Optional[cute.Tensor] = None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + aux_tensors: Optional[list] = None, + ): + """Execute the Fused Multi-Head Attention operation on the provided tensors. + + This method prepares the input tensors for processing, validates their shapes and types, + configures the computation parameters, and launches the CUDA kernel. + + The method handles: + 1. Tensor layout transformations for specific memory access patterns + 2. Validation of tensor shapes and data types + 3. Initialization of hardware-specific parameters and memory layouts + 4. 
Configuration of TMA (Tensor Memory Access) operations + 5. Grid and work scheduling computation + 6. Kernel launch with appropriate parameters + """ + # setup static attributes before smem/grid/tma computation + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.o_dtype = mO.element_type + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: ( + *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), + t.stride[-1], + ) + mQ, mK, mV, mO = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + for t in (mQ, mK, mV, mO) + ] + Q_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ = cute.make_tensor(mQ.iterator, cute.select(mQ.layout, mode=Q_layout_transpose)) + # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there's cu_seqlens_k or (page_size, d, h_k, num_pages) if there's page_table + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mK, mV = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) + for t in (mK, mV) + ] + if const_expr(self.is_split_kv): + O_layout_transpose = [2, 4, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 3, 2, 0] + LSE_layout_transpose = [3, 2, 1, 0] if const_expr(mCuSeqlensQ is None) else [2, 1, 0] + num_splits = mO.shape[0] + else: + O_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + num_splits = Int32(1) + mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) + mLSE = ( + cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) + if const_expr(mLSE is not None) + else None + ) + # (s, d, h, b) -> (d, s, h, b) + V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] + mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose)) + + self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode() + self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode() + self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode() + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO) + + if const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mQ is not supported") + if const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mK is not supported") + if const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): + raise RuntimeError("The layout of mV is not supported") + + # check type consistency + if const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + self._setup_attributes() + self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None + # This can be tuned + self.e2e_freq = 16 + if const_expr( + self.head_dim_padded > 64 and not self.is_causal and not self.is_local and self.pack_gqa + ): + self.e2e_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10 + + cta_group = tcgen05.CtaGroup.ONE + # the intermediate tensor p is from tmem & mK-major + p_source = tcgen05.OperandSource.TMEM + p_major_mode = tcgen05.OperandMajorMode.K + tiled_mma_qk = 
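# Illustrative aside (hedged): what the mode-select transposes below do, in
# plain tuple arithmetic. Selecting modes [1, 3, 2, 0] of a (b, s, h, d)
# layout yields the (s, d, h, b) view the kernel works in, without moving any
# data -- only the (shape, stride) pairs are permuted. The shapes and strides
# here are made-up example numbers.
def select_modes(shape, stride, modes):
    return tuple(shape[m] for m in modes), tuple(stride[m] for m in modes)

shape  = (2, 1024, 32, 128)                     # (b, s_q, h, d)
stride = (1024 * 32 * 128, 32 * 128, 128, 1)    # row-major strides

q_shape, q_stride = select_modes(shape, stride, (1, 3, 2, 0))
assert q_shape == (1024, 128, 32, 2)            # (s_q, d, h, b)

# V additionally swaps its first two modes so the head dim comes first
# (reusing the same example shape as a stand-in for the K/V view):
v_shape, _ = select_modes(q_shape, q_stride, (1, 0, 2, 3))
assert v_shape == (128, 1024, 32, 2)            # (d, s_k, h_k, b_k)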
sm100_utils_basic.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.qk_acc_dtype, + cta_group, + self.mma_tiler_qk[:2], + ) + tiled_mma_pv = sm100_utils_basic.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.pv_acc_dtype, + cta_group, + self.mma_tiler_pv[:2], + p_source, + ) + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (tiled_mma_qk.thr_id.shape,), + ) + + self.epi_tile = self.mma_tiler_pv[:2] + + sQ_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_qk, + self.mma_tiler_qk, + self.q_dtype, + self.q_stage, + ) + sK_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_qk, + self.mma_tiler_qk, + self.k_dtype, + self.kv_stage, + ) + tP_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_pv, + self.mma_tiler_pv, + self.q_dtype, + self.acc_stage, + ) + sV_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_pv, + self.mma_tiler_pv, + self.v_dtype, + self.kv_stage, + ) + sO_layout = sm100_utils_basic.make_smem_layout_epi( + self.o_dtype, + self.o_layout, + self.epi_tile, + self.q_stage, + ) + if const_expr(not self.same_hdim_kv_padded): + # sK and sV are using the same physical smem so we need to adjust the stride so that they line up + stride_sK = const_expr( + max(sK_layout.outer.stride[-1], 0) + ) # take max to turn tuple to Int32 + stride_sV = const_expr(max(sV_layout.outer.stride[-1], 0)) + stage_stride = const_expr( + max(stride_sK, stride_sV) + if not self.uneven_kv_smem + else (stride_sK + stride_sV) // 2 + ) + sK_layout = cute.make_composed_layout( + sK_layout.inner, + 0, + cute.make_layout( + (*sK_layout.outer.shape[:-1], self.kv_stage), + stride=(*sK_layout.outer.stride[:-1], stage_stride), + ), + ) + sV_layout = cute.make_composed_layout( + sV_layout.inner, + 0, + cute.make_layout( + (*sV_layout.outer.shape[:-1], self.kv_stage), + stride=(*sV_layout.outer.stride[:-1], stage_stride), + ), + ) + + if const_expr(self.pack_gqa): + shape_Q_packed = ( + (self.qhead_per_kvhead, mQ.shape[0]), + mQ.shape[1], + mK.shape[2], + *mQ.shape[3:], + ) + stride_Q_packed = ( + (mQ.stride[2], mQ.stride[0]), + mQ.stride[1], + mQ.stride[2] * self.qhead_per_kvhead, + *mQ.stride[3:], + ) + mQ = cute.make_tensor( + mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed) + ) + shape_O_packed = ( + (self.qhead_per_kvhead, mO.shape[0]), + mO.shape[1], + mK.shape[2], + *mO.shape[3:], + ) + stride_O_packed = ( + (mO.stride[2], mO.stride[0]), + mO.stride[1], + mO.stride[2] * self.qhead_per_kvhead, + *mO.stride[3:], + ) + mO = cute.make_tensor( + mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed) + ) + if const_expr(mLSE is not None): + shape_LSE_packed = ( + (self.qhead_per_kvhead, mLSE.shape[0]), + mK.shape[2], + *mLSE.shape[2:], + ) + stride_LSE_packed = ( + (mLSE.stride[1], mLSE.stride[0]), + mLSE.stride[1] * self.qhead_per_kvhead, + *mLSE.stride[2:], + ) + mLSE = cute.make_tensor( + mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) + ) + + self.tma_copy_bytes = { + name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) + for name, mX, layout in [ + ("Q", mQ, sQ_layout), + ("K", mK, sK_layout), + ("V", mV, sV_layout), + ] + } + + # TMA load for Q + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + + tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + 
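# Illustrative aside (hedged): how the PackGQA re-lay-out above folds the
# qhead_per_kvhead query heads that share one KV head into the row mode of Q,
# written as plain index arithmetic on the transposed (s_q, d, h, b) view.
# The strides are hypothetical example values; only the mapping mirrors the
# shape_Q_packed / stride_Q_packed construction.
g = 4                                        # qhead_per_kvhead
stride = (4096, 1, 128, 1024 * 4096)         # (stride_s, stride_d, stride_h, stride_b)

def packed_q_offset(q_in_group, s, d, h_kv, b):
    # packed shape  : ((g, s_q), d, h_kv, b)
    # packed stride : ((stride_h, stride_s), stride_d, stride_h * g, stride_b)
    return (q_in_group * stride[2] + s * stride[0]) + d * stride[1] \
        + h_kv * stride[2] * g + b * stride[3]

def original_q_offset(s, d, h, b):
    return s * stride[0] + d * stride[1] + h * stride[2] + b * stride[3]

# Packed row (q_in_group=1, s=7) under KV head 3 is query head 3*g + 1 originally.
assert packed_q_offset(1, 7, 0, 3, 0) == original_q_offset(7, 0, 3 * g + 1, 0)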
mQ, + cute.select(sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + + if const_expr(self.use_tma_KV): + # TMA load for K + tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + mK, + cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + # TMA load for V + tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + mV, + cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, + tiled_mma_pv, + self.cluster_layout_vmnk.shape, + ) + else: + tma_atom_K = None + tma_atom_V = None + + o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile) + + self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) + if const_expr(self.use_tma_O): + tma_atom_O, mO = cpasync.make_tiled_tma_atom( + tma_store_op, + mO, + cute.select(sO_layout, mode=[0, 1]), + o_cta_v_layout, + ) + gmem_tiled_copy_O = None + else: + tma_atom_O = None + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.o_dtype.width + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.o_dtype, + num_bits_per_copy=universal_copy_bits, + ) + tO_shape_dim_1 = sO_layout.outer.shape[1][0] // async_copy_elems + tO_layout = cute.make_ordered_layout( + (self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1), + order=(1, 0), + ) + # So that we don't have to check if we overshoot kBlockM when we store O + assert self.m_block_size % tO_layout.shape[0] == 0 + vO_layout = cute.make_layout((1, async_copy_elems)) + gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) + + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler + else: + if const_expr(self.is_causal or self.is_local): + TileScheduler = SingleTileLPTScheduler + else: + TileScheduler = ( + SingleTileScheduler + if const_expr(not self.is_persistent) + else StaticPersistentTileScheduler + ) + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else cute.size(mCuSeqlensQ.shape[0] - 1), + num_splits, + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1], + mQ.shape[1], + mV.shape[0], # Note that this is different from Sm90 since we transpose mV in Sm100 + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=self.cta_tiler[:2], + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.k_dtype.width // 8, + is_persistent=self.is_persistent, + lpt=self.is_causal or self.is_local, + is_split_kv=self.is_split_kv, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + self.tile_scheduler_cls = TileScheduler + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + + self.mbar_load_q_full_offset = 0 + self.mbar_load_q_empty_offset = self.mbar_load_q_full_offset + self.q_stage + self.mbar_load_kv_full_offset = self.mbar_load_q_empty_offset + self.q_stage + self.mbar_load_kv_empty_offset = self.mbar_load_kv_full_offset + self.kv_stage + self.mbar_P_full_O_rescaled_offset = self.mbar_load_kv_empty_offset + self.kv_stage + self.mbar_S_full_offset = 
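# Illustrative aside (hedged): the thread/value tiling arithmetic for the
# non-TMA O store below, with example numbers (bf16 output, head_dim_v = 128,
# one 32-thread epilogue warp). Each thread moves 128 bits per copy, threads
# cover the head dimension first, and the row count per pass must divide
# m_block_size so the store needs no bounds check on M. Only the arithmetic
# from the code is reproduced; the variable names are made up.
o_dtype_bits, head_dim_v, num_epilogue_threads, m_block_size = 16, 128, 32, 128

async_copy_elems = 128 // o_dtype_bits                   # 8 elements per 128-bit copy
threads_along_dv = head_dim_v // async_copy_elems        # 16 threads cover one row
rows_per_pass = num_epilogue_threads // threads_along_dv  # 2 rows per pass

assert m_block_size % rows_per_pass == 0                 # mirrors the assert in the code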
self.mbar_P_full_O_rescaled_offset + self.q_stage + self.mbar_O_full_offset = self.mbar_S_full_offset + self.q_stage + self.mbar_softmax_corr_full_offset = self.mbar_O_full_offset + self.q_stage + self.mbar_softmax_corr_empty_offset = self.mbar_softmax_corr_full_offset + self.q_stage + self.mbar_corr_epi_full_offset = self.mbar_softmax_corr_empty_offset + self.q_stage + self.mbar_corr_epi_empty_offset = self.mbar_corr_epi_full_offset + self.q_stage + self.mbar_s0_s1_sequence_offset = self.mbar_corr_epi_empty_offset + self.q_stage + self.mbar_tmem_dealloc_offset = self.mbar_s0_s1_sequence_offset + 8 + self.mbar_P_full_2_offset = self.mbar_tmem_dealloc_offset + 1 + self.mbar_total = self.mbar_P_full_2_offset + self.q_stage + + sO_size = cute.cosize(sO_layout) if const_expr(not self.overlap_sO_sQ) else 0 + sQ_size = ( + cute.cosize(sQ_layout) if const_expr(not self.overlap_sO_sQ) else + cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width) + ) + + @cute.struct + class SharedStorage: + # m_barriers for pipelines + mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mbar_total] + # Tmem holding buffer + tmem_holding_buf: Int32 + # Smem tensors + # store row max and row sum + sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] + sO: cute.struct.Align[ + cute.struct.MemRange[self.o_dtype, sO_size], + self.buffer_align_bytes, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, sQ_size], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + # cute.cosize(sK_layout) is correct even in the case of self.uneven_kv_smem + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + LOG2_E = math.log2(math.e) + if const_expr(self.score_mod is None): + softmax_scale_log2 = softmax_scale * LOG2_E + softmax_scale = None + else: + # NB: If a users passes in a score mod, we want to apply the score-mod in the sm_scaled qk + # But in the original base 10. 
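# Illustrative aside (hedged): why softmax_scale is folded into a base-2
# exponent below. exp(x * scale) == 2 ** (x * scale * log2(e)), so the kernel
# can use the cheaper exp2 path; when a score_mod is present, only the change
# of base log2(e) is kept and the scale is applied separately before the mod.
# A tiny numeric check, nothing more.
import math

softmax_scale = 1.0 / math.sqrt(128)                # e.g. 1/sqrt(head_dim)
softmax_scale_log2 = softmax_scale * math.log2(math.e)

x = 3.7                                             # an arbitrary attention score
assert math.isclose(math.exp(x * softmax_scale), 2.0 ** (x * softmax_scale_log2))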
We hijack softmax_scale_log2 to just be the change of base + # and correctly apply the softmax_scale prior to score_mod in the softmax step + softmax_scale_log2 = LOG2_E + softmax_scale = softmax_scale + + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + + fastdiv_mods = None + if cutlass.const_expr(aux_tensors is not None): + seqlen_q = cute.size(mQ.shape[0]) // ( + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 + ) + seqlen_k = ( + cute.size(mK.shape[0]) + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1] + ) + seqlen_q_divmod = FastDivmodDivisor(seqlen_q) + seqlen_k_divmod = FastDivmodDivisor(seqlen_k) + fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + + head_divmod = None + if cutlass.const_expr(self.pack_gqa): + head_divmod = FastDivmodDivisor(self.qhead_per_kvhead) + + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + if cutlass.const_expr(self.use_block_sparsity and mPageTable is not None): + raise NotImplementedError("Block sparsity + paged KV not supported on SM100") + + # Launch the kernel synchronously + self.kernel( + mQ, + mK, + mV, + mO, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_O, + softmax_scale_log2, + softmax_scale, + window_size_left, + window_size_right, + learnable_sink, + blocksparse_tensors, + sQ_layout, + sK_layout, + tP_layout, + sV_layout, + sO_layout, + gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, + tile_sched_params, + num_splits, + aux_tensors, + fastdiv_mods, + head_divmod, + ).launch( + grid=grid_dim, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, # (s_q, d, h, b) or (total_q, d, h) if there is cu_seqlens_q + mK: cute.Tensor, # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there is cu_seqlens_k or (page_size, d, h_k, num_pages) if there is page_table + mV: cute.Tensor, # (d, s_k, h_k, b_k) or (d, total_k, h_k) if there is cu_seqlens_k or (d, page_size, h_k, num_pages) if there is page_table + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mPageTable: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + tma_atom_O: Optional[cute.CopyAtom], + softmax_scale_log2: Float32, + softmax_scale: Float32 | None, + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + learnable_sink: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + tP_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + gmem_tiled_copy_O: Optional[cute.TiledCopy], + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tile_sched_params: ParamsBase, + num_splits: Int32, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + head_divmod=None, + ): + """The device kernel implementation of the Fused Multi-Head Attention. 
+ + This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation: + 1. Load warp: Loads Q, K, V data from global memory to shared memory using TMA + 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) + 3. Softmax warps: Compute softmax normalization on attention scores + 4. Correction warps: Apply adjustments to intermediate results + 5. Epilogue warp: Handles final output transformation and storage + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases, and optional attention masking. + """ + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # Prefetch tma descriptor + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_Q) + if const_expr(tma_atom_K is not None): + cpasync.prefetch_descriptor(tma_atom_K) + if const_expr(tma_atom_V is not None): + cpasync.prefetch_descriptor(tma_atom_V) + if const_expr(tma_atom_O is not None): + cpasync.prefetch_descriptor(tma_atom_O) + + # Alloc + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + mbar_ptr = storage.mbar_ptr.data_ptr() + # Use the first N warps to initialize barriers + if warp_idx == 1: + # Init "full" barrier with number of producers, "empty" barrier with number of consumers + for i in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_load_q_full_offset + i, 1 + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id]) + ) + if warp_idx == 2: + for i in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4 + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4 + ) + if warp_idx == 3: + if const_expr(self.s0_s1_barrier): + for i in cutlass.range_constexpr(8): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE + ) + if const_expr(not self.use_correction_warps_for_epi) and warp_idx == 4: + for i in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_corr_epi_full_offset + i, + cute.arch.WARP_SIZE * len(self.correction_warp_ids), + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_corr_epi_empty_offset + i, + cute.arch.WARP_SIZE * len(self.epilogue_warp_ids), + ) + if warp_idx == 5: + for i in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, + cute.arch.WARP_SIZE + * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids)), + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id]) + ) + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id]) + ) + if warp_idx == 6: + for i in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_P_full_2_offset + i, + cute.arch.WARP_SIZE * len(self.softmax0_warp_ids), + ) + if warp_idx == 7: + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_tmem_dealloc_offset, + cute.arch.WARP_SIZE + * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + ) + ), + ) + # Relying on pipeline_kv constructor to call mbarrier_init_fence and sync + pipeline_kv = self.make_and_init_load_kv_pipeline(mbar_ptr + self.mbar_load_kv_full_offset) + 
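# Illustrative aside (hedged): a tiny, pure-Python model of the phase-based
# arrive-count barriers initialized above -- a "full" barrier expects one
# arrival per producer, an "empty" barrier one per consumer, and waiters
# observe a phase bit that flips each time the expected count is reached.
# This is only a conceptual sketch; it is not the PTX mbarrier API.
class ArriveCountBarrier:
    def __init__(self, expected_arrivals: int):
        self.expected = expected_arrivals
        self.count = 0
        self.phase = 0  # flips every time `expected` arrivals are observed

    def arrive(self):
        self.count += 1
        if self.count == self.expected:
            self.count = 0
            self.phase ^= 1  # release everyone waiting on the old phase

    def try_wait(self, phase: int) -> bool:
        # a real waiter spins until the barrier phase moves past `phase`
        return self.phase != phase

# Example: a "full" barrier signaled by the single MMA warp, paired with an
# "empty" barrier released by one 4-warp softmax warpgroup (128 threads).
full = ArriveCountBarrier(expected_arrivals=1)
empty = ArriveCountBarrier(expected_arrivals=4 * 32)
full.arrive()                       # MMA warp signals S is ready
assert full.try_wait(phase=0)       # softmax side observes the flip
for _ in range(4 * 32):
    empty.arrive()                  # softmax threads release the buffer
assert empty.try_wait(phase=0)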
+ # Generate smem tensor Q/K/V/O + # (MMA, MMA_Q, MMA_D, PIPE) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + # (MMA, MMA_K, MMA_D, PIPE) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + # (MMA, MMA_K, MMA_D, PIPE) + # Strip swizzle info to reuse smem + sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer) + if const_expr(not self.overlap_sO_sQ): + sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) + else: + sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner, self.o_dtype), sO_layout.outer) + + sScale = storage.sScale.get_tensor(cute.make_layout(self.q_stage * self.m_block_size * 2)) + + thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM + thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM + + qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2]) + tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape) + # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always + # request 512 columns of tmem, so we know that it starts at 0. + tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16) + tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout) + + pv_acc_shape = thr_mma_pv.partition_shape_C(self.mma_tiler_pv[:2]) + tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape) + + tStSs = tuple( + cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout) + for stage in range(self.q_stage) + ) + tOtOs = tuple( + cute.make_tensor(tOtO.iterator + self.tmem_o_offset[stage], tOtO.layout) + for stage in range(self.q_stage) + ) + + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + + tOrPs = [ + cute.make_tensor( + tOrP.iterator + + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p_offset[stage], + tOrP.layout, + ) + for stage in range(self.q_stage) + ] + + block_info = BlockInfo( + # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1]) + self.cta_tiler[0], + self.cta_tiler[1], + self.is_causal, + self.is_local, + self.is_split_kv, + window_size_left, + window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0] + if const_expr(mPageTable is None) + else mK.shape[0] * mPageTable.shape[1], + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + ) + AttentionMaskCls = partial( + AttentionMask, + self.m_block_size, + self.n_block_size, + window_size_left=window_size_left, + window_size_right=window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) + + # /////////////////////////////////////////////////////////////////////////////// + # EMPTY + # /////////////////////////////////////////////////////////////////////////////// + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.load_warp_ids[0] and warp_idx <= 
self.load_warp_ids[-1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.load( + thr_mma_qk, + thr_mma_pv, + mQ, + mK, + mV, + sQ, + sK, + sV, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_kv, + mbar_ptr, + block_info, + num_splits, + SeqlenInfoCls, + TileSchedulerCls, + blocksparse_tensors, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + # Alloc tmem buffer + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + if warp_idx == self.mma_warp_id: + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.sync_warp() + + self.mma( + tiled_mma_qk, + tiled_mma_pv, + sQ, + sK, + sV, + tStSs, + tOtOs, + tOrPs, + pipeline_kv, + mbar_ptr, + block_info, + num_splits, + SeqlenInfoCls, + TileSchedulerCls, + blocksparse_tensors, + ) + + # if warp_idx == self.mma_warp_id: + # dealloc tmem buffer + cute.arch.relinquish_tmem_alloc_permit() + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_tmem_dealloc_offset, 0) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + # Retrieving tmem ptr and make acc + tmem_ptr = cute.arch.retrieve_tmem_ptr( + Float32, + alignment=16, + ptr_to_buffer_holding_addr=storage.tmem_holding_buf, + ) + cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + if const_expr(not self.use_correction_warps_for_epi): + if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.epilogue_s2g( + mO, + sO, + gmem_tiled_copy_O, + tma_atom_O, + mbar_ptr, + block_info, + num_splits, + SeqlenInfoCls, + TileSchedulerCls, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Softmax + # /////////////////////////////////////////////////////////////////////////////// + if ( + (const_expr(self.q_stage == 2) and warp_idx <= self.softmax1_warp_ids[-1]) or + (const_expr(self.q_stage == 1) and warp_idx <= self.softmax0_warp_ids[-1]) + ): + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + softmax_loop = partial( + self.softmax_loop, + softmax_scale_log2=softmax_scale_log2, + softmax_scale=softmax_scale, + thr_mma_qk=thr_mma_qk, + sScale=sScale, + mLSE=mLSE, + learnable_sink=learnable_sink, + mbar_ptr=mbar_ptr, + block_info=block_info, + num_splits=num_splits, + SeqlenInfoCls=SeqlenInfoCls, + AttentionMaskCls=AttentionMaskCls, + TileSchedulerCls=TileSchedulerCls, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + head_divmod=head_divmod, + blocksparse_tensors=blocksparse_tensors, + ) + + if const_expr(not self.s0_s1_barrier): + stage = Int32(0 if const_expr(self.q_stage == 1) or warp_idx < self.softmax1_warp_ids[0] else 1) + softmax_loop( + stage=stage, + tStSi=cute.make_tensor( + tStS.iterator + + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]), + tStS.layout, + ), + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + else: + # If there's s0_s1_barrier, it's faster to have 2 WGs having different code + if warp_idx < self.softmax1_warp_ids[0]: + tStSi = cute.make_tensor(tStS.iterator 
+ self.tmem_s_offset[0], tStS.layout) + softmax_loop(stage=0, tStSi=tStSi) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax1_warp_ids[0]: + tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[1], tStS.layout) + softmax_loop(stage=1, tStSi=tStSi) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) + self.correction_loop( + thr_mma_qk, + thr_mma_pv, + tStS, + tOtOs, + sScale, + mO, + mLSE, + sO, + learnable_sink, + gmem_tiled_copy_O, + tma_atom_O, + mbar_ptr, + softmax_scale_log2, + block_info, + num_splits, + SeqlenInfoCls, + TileSchedulerCls, + blocksparse_tensors, + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + + return + + @cute.jit + def load( + self, + thr_mma_qk: cute.core.ThrMma, + thr_mma_pv: cute.core.ThrMma, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + mPageTable: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + pipeline_kv: cutlass.pipeline.PipelineAsync, + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + num_splits: Int32, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], + ): + num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE + tidx = cute.arch.thread_idx()[0] % num_load_threads + q_producer_phase = Int32(1) + kv_producer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.kv_stage + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0)) + + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) + if const_expr(mPageTable is None): + if const_expr(not seqlen.has_cu_seqlens_k): + mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] + else: + mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, head_idx_kv]) + mV_cur = cute.domain_offset((0, seqlen.offset_k), mV[None, None, head_idx_kv]) + gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0)) + gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None)) + else: + # Need to keep batch coord None since we'll index into it with page idx + mK_cur, mV_cur = [t[None, None, head_idx_kv, None] for t in (mK, mV)] + gK = cute.local_tile( + mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None) + ) + gV = cute.local_tile( + mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None) + ) + tSgQ = thr_mma_qk.partition_A(gQ) + tSgK = thr_mma_qk.partition_B(gK) + tOgV = thr_mma_pv.partition_B(gV) + load_Q_fn, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ + ) + + if 
const_expr(self.use_tma_KV): + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + paged_kv_manager = None + else: + page_size = mK.shape[0] + paged_kv_manager = PagedKVManager.create( + mPageTable, + mK, + mV, + FastDivmodDivisor(page_size), + batch_idx, + head_idx_kv, + tidx, + seqlen.seqlen_k, + 0, # leftpad_k + self.n_block_size, + self.head_dim_padded, + self.head_dim_v_padded, + num_load_threads, + mK.element_type, + ) + tKsK, tKgK = None, None + tVsV, tVgV = None, None + + load_Q = partial( + self.load_Q, + load_Q_fn, + mbar_ptr + self.mbar_load_q_full_offset, + mbar_ptr + self.mbar_load_q_empty_offset, + phase=q_producer_phase, + ) + # We have to use mbarrier directly in the load for KV instead of replying on + # pipeline_kv, because we could have different number of TMA bytes for K and V + load_K = partial( + self.load_KV, + tma_atom_K, + tKgK, + tKsK, + paged_kv_manager, + sK, + mbar_ptr + self.mbar_load_kv_full_offset, + mbar_ptr + self.mbar_load_kv_empty_offset, + K_or_V="K", + ) + load_V = partial( + self.load_KV, + tma_atom_V, + tVgV, + tVsV, + paged_kv_manager, + sV, + mbar_ptr + self.mbar_load_kv_full_offset, + mbar_ptr + self.mbar_load_kv_empty_offset, + K_or_V="V", + ) + + if const_expr(not self.use_block_sparsity): + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, split_idx, num_splits + ) + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + if const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE: + load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 + n_block_first = n_block_max - 1 if n_block_max > 0 else 0 + page_idx = ( + mPageTable[batch_idx, n_block_first] + if const_expr(mPageTable is not None and self.use_tma_KV) + else None + ) + if const_expr(not self.use_tma_KV): + paged_kv_manager.load_page_table(n_block_first) + load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 + kv_producer_state.advance() + if const_expr(self.q_stage == 2) and (const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE): + load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 + q_producer_phase ^= 1 + load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 + kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 2 - i + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None and self.use_tma_KV) + else None + ) + if const_expr(not self.use_tma_KV): + paged_kv_manager.load_page_table(n_block) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) + load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki + kv_producer_state.advance() + load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi + kv_producer_state.advance() + + else: + kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_kv, + self.q_stage, + q_producer_phase, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + + + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + 
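# Illustrative aside (hedged): the issue order the producer loop above follows
# for one work tile, written as a plain-Python generator. KV blocks are walked
# from n_block_max-1 down to n_block_min, with K and V interleaved so the QK
# MMA for a block can start while its V is still in flight. Purely a sketch of
# the ordering; the names here are hypothetical.
def load_issue_order(n_block_min: int, n_block_max: int, q_stage: int = 2):
    n_last = n_block_max - 1
    yield ("Q", 0)
    yield ("K", n_last)
    if q_stage == 2:
        yield ("Q", 1)
    yield ("V", n_last)
    for n_block in range(n_block_max - 2, n_block_min - 1, -1):
        yield ("K", n_block)
        yield ("V", n_block)

# e.g. 4 KV blocks, two Q stages:
print(list(load_issue_order(0, 4)))
# [('Q', 0), ('K', 3), ('Q', 1), ('V', 3), ('K', 2), ('V', 2),
#  ('K', 1), ('V', 1), ('K', 0), ('V', 0)]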
work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def mma( + self, + tiled_mma_qk: cute.core.ThrMma, + tiled_mma_pv: cute.core.ThrMma, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tStSs: Tuple[cute.Tensor, cute.Tensor], + tOtOs: tuple[cute.Tensor], + tOrPs: Tuple[cute.Tensor, cute.Tensor], + pipeline_kv: cutlass.pipeline.PipelineAsync, + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + num_splits: Int32, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], + ): + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrK = tiled_mma_qk.make_fragment_B(sK) + tOrV = tiled_mma_pv.make_fragment_B(sV) + if const_expr(self.q_stage == 2): + tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 1]) + else: + tSrQs = (tSrQ[None, None, None, 0],) + + qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op + + gemm_Si = [ + partial( + sm100_utils.gemm_ptx_partial, + qk_mma_op, + self.tmem_s_offset[stage], + tSrQs[stage], + sA=sQ[None, None, None, stage], + zero_init=True, + ) + for stage in range(self.q_stage) + ] + gemm_Pi = [ + partial( + sm100_utils.gemm_ptx_partial, + pv_mma_op, + self.tmem_o_offset[stage], + tOrPs[stage], + sA=None, + ) + for stage in range(self.q_stage) + ] + + mma_q_consumer_phase = Int32(0) + mma_kv_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.kv_stage + ) + P_full_O_rescaled_phase = Int32(0) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + + block_iter_count = Int32(0) + process_tile = False + + if const_expr(self.use_block_sparsity): + block_iter_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) + process_tile = block_iter_count > Int32(0) + else: + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + block_iter_count = n_block_max - n_block_min + if const_expr(not self.is_split_kv): + process_tile = True + else: + process_tile = n_block_min < n_block_max + + if process_tile: + for stage in cutlass.range_constexpr(self.q_stage): + # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) + # 1. wait for Q0 / Q1 + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase + ) + # 2. wait for K0 + if const_expr(stage == 0): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] + # We don't need to acquire empty S0 / S1. + # For the first iteration, we don't need to wait as we're guaranteed S0 / S1 + # are empty. For subsequent iterations, the wait happened at the end + # of the while loop. + # 3. gemm + # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) + sK_cur = sK[None, None, None, mma_kv_consumer_state.index] + if const_expr(self.uneven_kv_smem): + sK_cur = self.offset_kv_smem( + sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase + ) + gemm_Si[stage](tCrB=tSrKi, sB=sK_cur) + # 4. release S0 / S1 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + mma_q_consumer_phase ^= 1 + # 5. 
release K0 + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + # End of GEMM (Q1 * K0 -> S1) + # Note: Q0 & Q1 are still needed in the seqlen_kv loop + # so we need to release them after the seqlen_kv loop + + # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate + block_loop_count = block_iter_count - 1 + O_should_accumulate = False + for i in cutlass.range(block_loop_count, unroll=1): + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 + pipeline_kv.consumer_wait(mma_kv_consumer_state) + mma_kv_release_state = mma_kv_consumer_state.clone() + Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase + tOrVi = tOrV[None, None, None, Vi_index] + for stage in cutlass.range_constexpr(self.q_stage): + # 2. acquire corrected O0/O1_partial and P0 / P1 + # For the first iteration in this work tile, waiting for O0/O1_partial + # means that the correction warps has finished reading tO during + # the last iteration of the previous work tile has finished. + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, + P_full_O_rescaled_phase, + ) + # 3. gemm + # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) + # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage]( + tCrB=tOrVi, + sB=sV_cur, + zero_init=not O_should_accumulate, + mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_phase=P_full_O_rescaled_phase, + ) + # 4. release accumulated O0_partial / O1_partial + # Don't need to signal O_full to the correction warps anymore since the + # correction warps wait for the softmax warps anyway. By the time the softmax + # warps finished, S_i for the next iteration must have been done, so O_i-1 + # must have been done as well. + # with cute.arch.elect_one(): + # tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # 5. release V(i-1) + if const_expr(stage == self.q_stage - 1): + pipeline_kv.consumer_release(mma_kv_release_state) + mma_kv_release_state.advance() + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + + # GEMM_QK0i (Q0 * Ki -> S0) + # 1. wait for Ki + if const_expr(stage == 0): + mma_kv_consumer_state.advance() + pipeline_kv.consumer_wait(mma_kv_consumer_state) + Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase + # 2. gemm + # Don't need to wait for the softmax warp to have finished reading the previous + # Si, since this gemm is scheduled after the PV gemm, which guaranteed that Si + # has been read and Pi has been written. + # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrK[None, None, None, Ki_index], zero_init=True) + sK_cur = sK[None, None, None, Ki_index] + if const_expr(self.uneven_kv_smem): + sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase) + gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur) + # 3. release S0 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + # End of GEMM_QK0i (Q0 * Ki -> S0) + # 4. 
release Ki + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + P_full_O_rescaled_phase ^= 1 + O_should_accumulate = True + # End of seqlen_kv loop + + # release Q0 & Q1 + with cute.arch.elect_one(): + for stage in cutlass.range_constexpr(self.q_stage): + tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + stage) + + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 + pipeline_kv.consumer_wait(mma_kv_consumer_state) + Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase + tOrVi = tOrV[None, None, None, Vi_index] + for stage in cutlass.range_constexpr(self.q_stage): + # 2. acquire corrected Oi_partial and Pi + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase + ) + # 3. gemm + # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) + # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage]( + tCrB=tOrVi, + sB=sV_cur, + zero_init=not O_should_accumulate, + mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_phase=P_full_O_rescaled_phase, + ) + # 4. release accumulated O0_partial + # We do need O_full here since for the last tile, by the time the softmax warp + # has signaled to the correction warps, the softmax warp has just finished compute + # the row sum of the current tile. It does not guarantee that the 1st tile + # of the next work tile has been computed yet. + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + P_full_O_rescaled_phase ^= 1 + # 5. release Vi_end + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + + # for both softmax0 and softmax1 warp group + @cute.jit + def softmax_loop( + self, + stage: int | Int32, + softmax_scale_log2: Float32, + softmax_scale: Float32, + thr_mma_qk: cute.core.ThrMma, + tStSi: cute.Tensor, + sScale: cute.Tensor, + mLSE: Optional[cute.Tensor], + learnable_sink: Optional[cute.Tensor], + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + num_splits: Int32, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + head_divmod=None, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + ): + """Compute softmax on attention scores from QK matrix multiplication. + + This method handles the softmax computation for either the first or second half of the + attention matrix, depending on the 'stage' parameter. It calculates row-wise maximum + and sum values needed for stable softmax computation, applies optional masking, and + transforms raw attention scores into probability distributions. + + The implementation uses specialized memory access patterns and efficient math operations + for computing exp(x) using exp2 functions. It also coordinates pipeline + synchronization between MMA, correction, and sequence processing stages. 
+ """ + tidx = cute.arch.thread_idx()[0] % ( + cute.arch.WARP_SIZE + # * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids) + * (len(self.softmax0_warp_ids)) + ) + + tStScale = cute.composition(tStSi, cute.make_layout((self.m_block_size, 1))) + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) + + tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width + tStP_layout = cute.composition( + tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32)) + ) + tStP = cute.make_tensor(tStSi.iterator + self.tmem_s_to_p_offset, tStP_layout) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + Float32, + ) + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi).get_slice(tidx) + tStS_t2r = thr_tmem_load.partition_S(tStSi) + + tmem_store_scale_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), + Float32, + ) + thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice( + tidx + ) + + tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), + Float32, + ) + thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + mma_si_consumer_phase = Int32(0) + si_corr_producer_phase = Int32(1) + s0_s1_sequence_phase = Int32(1 if stage == 0 else 0) + + # self.warp_scheduler_barrier_init() + + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + mbar_s0_s1_sequence_offset = self.mbar_s0_s1_sequence_offset + warp_idx_in_wg + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + + mask = AttentionMaskCls(seqlen) + shared_mask_kwargs = dict( + m_block=self.q_stage * m_block + stage, + thr_mma=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + mask_causal=self.is_causal, + mask_local=self.is_local, + batch_idx=batch_idx, + head_idx=head_idx, + aux_tensors=aux_tensors, + ) + + # Recompute fastdiv_mods if necessary + recompute_fastdiv_mods_q = cutlass.const_expr( + aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) + ) + recompute_fastdiv_mods_k = cutlass.const_expr( + aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k) + ) + + if cutlass.const_expr(fastdiv_mods is not None): + seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods + fastdiv_mods = ( + seqlen_q_divmod + if not recompute_fastdiv_mods_q + else FastDivmodDivisor(seqlen.seqlen_q), + seqlen_k_divmod + if not recompute_fastdiv_mods_k + else FastDivmodDivisor(seqlen.seqlen_k), + ) + + mask_mod = self.mask_mod if const_expr(self.mask_mod is not None) else None + mask_fn = partial( + mask.apply_mask_sm100, + mask_mod=mask_mod, + fastdiv_mods=fastdiv_mods, + head_divmod=head_divmod, + **shared_mask_kwargs, + ) + if const_expr(self.use_block_sparsity): + # Full blocks dont need mask_mod + mask_fn_none = partial( + mask.apply_mask_sm100, + mask_mod=None, + fastdiv_mods=fastdiv_mods, + head_divmod=head_divmod, + **shared_mask_kwargs, + ) + else: + mask_fn_none = None + + softmax = SoftmaxSm100.create( + 
softmax_scale_log2, + rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, + softmax_scale=softmax_scale, + ) + softmax.reset() + + if const_expr(self.use_block_sparsity): + tile_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) + has_work = tile_block_count > Int32(0) + else: + tile_block_count = n_block_max - n_block_min + has_work = const_expr(not self.is_split_kv) or tile_block_count > Int32(0) + + softmax_step = partial( + self.softmax_step, + softmax=softmax, + mbar_ptr=mbar_ptr, + mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset, + thr_mma_qk=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + thr_tmem_store=thr_tmem_store, + thr_tmem_store_scale=thr_tmem_store_scale, + tStS_t2r=tStS_t2r, + tStScale_r2t=tStScale_r2t, + tStP_r2t=tStP_r2t, + sScale=sScale, + stage=stage, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=self.q_stage * m_block + stage, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + head_divmod=head_divmod, + ) + + if has_work: + # Softmax acts as the producer: wait until correction signals the stage is empty + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase + ) + si_corr_producer_phase ^= 1 + + # Block sparse or dense iteration + if const_expr(self.use_block_sparsity): + # When aux_tensors exist, Q indices beyond seqlen_q must be wrapped to avoid + # OOB aux_tensor access. Only edge tiles (where m_tile_end > seqlen_q) need this. + if const_expr(aux_tensors is not None): + m_tile_end = (self.q_stage * m_block + stage + 1) * self.m_block_size + check_m_boundary = m_tile_end > seqlen.seqlen_q + else: + check_m_boundary = False + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + empty_tile, + ) = softmax_block_sparse_sm100( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + softmax_step, + mask_fn, + mask_fn_none, + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + mbar_ptr, + self.mbar_softmax_corr_full_offset, + self.mbar_softmax_corr_empty_offset, + self.mbar_P_full_O_rescaled_offset, + self.mbar_P_full_2_offset, + self.q_stage, + Int32(stage), + check_m_boundary, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + if not empty_tile: + sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[ + tidx + stage * self.m_block_size + self.m_block_size * 2 + ] = softmax.row_max[0] + # if tidx == 0: + # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) + else: + if const_expr(not self.is_split_kv) or tile_block_count > Int32(0): + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block_max - 1, + is_first=True, + mask_fn=partial(mask_fn, mask_seqlen=True), + ) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + n_block = 
n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + ) + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking (but may still need mask_mod) + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): + n_block = n_block_max - n_tile - 1 + if const_expr(self.mask_mod is not None): + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + else: + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, + ) + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): + n_block = n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + ) + # Now that we no longer already have the 1st iteration, need mask_seqlen=True here + + # Dense path always writes scale / signals + sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[ + tidx + stage * self.m_block_size + self.m_block_size * 2 + ] = softmax.row_max[0] + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + + # # Write LSE to gmem + # if const_expr(mLSE is not None): + # acc_O_mn_row_is_zero_or_nan = softmax.row_sum[0] == 0.0 or softmax.row_sum[0] != softmax.row_sum[0] + # scale = ( + # cute.arch.rcp_approx(softmax.row_sum[0] if not acc_O_mn_row_is_zero_or_nan else 1.0) + # ) + # LN2 = math.log(2.0) + # lse = ( + # (softmax.row_max[0] * softmax.scale_log2 + utils.log2f(softmax.row_sum[0])) * LN2 + # if not acc_O_mn_row_is_zero_or_nan else -Float32.inf + # ) + # if const_expr(not seqlen.has_cu_seqlens_q): + # mLSE_cur = mLSE[None, head_idx, batch_idx] + # else: + # mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) + # gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2 + stage,)) + # if tidx < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size: + # gLSE[tidx] = lse + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def softmax_step( + self, + mma_si_consumer_phase: Int32, + si_corr_producer_phase: Int32, + s0_s1_sequence_phase: Int32, + n_block: Int32, + softmax: SoftmaxSm100, + mbar_ptr: cute.Pointer, + mbar_s0_s1_sequence_offset: Int32, + thr_mma_qk: cute.core.ThrMma, + thr_tmem_load: cute.CopyAtom, + thr_tmem_store: cute.CopyAtom, + thr_tmem_store_scale: cute.CopyAtom, + tStS_t2r: cute.Tensor, + tStScale_r2t: cute.Tensor, + tStP_r2t: cute.Tensor, + sScale: cute.Tensor, + stage: int | Int32, + batch_idx: 
Int32, + head_idx: Int32, + m_block: Int32, + seqlen, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + head_divmod=None, + mask_fn: Optional[Callable] = None, + is_first: bool = False, + ) -> Tuple[cute.Int32, cute.Int32, cute.Int32]: + """Perform a single step of the softmax computation on a block of attention scores. + + This method processes one block of the attention matrix, computing numerically stable + softmax by first finding the row maximum, subtracting it from all elements, applying + exponential function, and then normalizing by the sum of exponentials. It also handles + optional masking of attention scores. + + The method involves several key operations: + 1. Loading attention scores from tensor memory + 2. Applying optional masking based on position + 3. Computing row-wise maximum values for numerical stability + 4. Transforming scores using exp2(x*scale - max*scale) + 5. Computing row sums for normalization + 6. Coordinating pipeline synchronization between different processing stages + """ + tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) + tScP = cute.composition(tScS, cute.make_layout((self.m_block_size, tilePlikeFP32))) + + # Wait for Si + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_S_full_offset + stage, mma_si_consumer_phase) + tSrS_t2r = cute.make_fragment(thr_tmem_load.partition_D(tScS).shape, self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + if cutlass.const_expr(self.score_mod is not None): + self.apply_score_mod( + tSrS_t2r, + thr_tmem_load, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + n_block, + softmax, + seqlen, + aux_tensors, + fastdiv_mods, + head_divmod, + ) + + if const_expr(mask_fn is not None): + mask_fn(tSrS_t2r, n_block=n_block) + row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first) + + if const_expr(not is_first): + # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScScale).shape, Float32) + # tSrScale_r2t[0] = acc_scale + # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) + # cute.arch.fence_view_async_tmem_store() + thread_idx = thr_tmem_load.thr_idx + sScale[thread_idx + stage * self.m_block_size] = acc_scale + # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max) + # Notify correction wg that row_max is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + + # if thread_idx == 0 and stage == 0: cute.print_tensor(tSrS_t2r) + # print(tSrS_t2r) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + # Sequence barrier wait + if const_expr(self.s0_s1_barrier): + cute.arch.mbarrier_wait( + mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase + ) + tSrP_r2t_f32 = cute.make_fragment(thr_tmem_store.partition_S(tScP).shape, Float32) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), + tSrS_t2r.layout, + ) + # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) + softmax.apply_exp2_convert( + tSrS_t2r, + tSrP_r2t, + e2e=mask_fn is None and self.head_dim_padded <= 128, + e2e_freq=self.e2e_freq, + ) + # Sequence barrier arrive + if const_expr(self.s0_s1_barrier): + cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) + # print(tSrP_r2t_f32, tStP_r2t) + # cute.copy(thr_tmem_store, tSrP_r2t_f32, 
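# Illustrative aside (hedged): a scalar, pure-Python model of the online
# softmax bookkeeping the step below performs -- update the running row max,
# derive a correction scale for the already-accumulated O, exponentiate with
# exp2, and fold the new block into the running row sum. Only the recurrence
# is modeled; tiles are plain lists and all names are made up.
import math

def online_attn_1d(score_blocks, value_blocks, scale_log2):
    row_max, row_sum, acc_o = -math.inf, 0.0, 0.0
    for scores, values in zip(score_blocks, value_blocks):
        new_max = max(row_max, max(scores))
        correction = 2.0 ** ((row_max - new_max) * scale_log2)  # <= 1.0
        acc_o *= correction                 # what the correction warps apply to O
        row_sum *= correction
        p = [2.0 ** ((s - new_max) * scale_log2) for s in scores]
        row_sum += sum(p)
        acc_o += sum(pi * vi for pi, vi in zip(p, values))  # stand-in for P @ V
        row_max = new_max
    return acc_o / row_sum                  # final normalization at the epilogue

scores = [[0.1, 1.5, -0.3], [2.0, 0.4, 0.9]]
values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
flat_s = [s for b in scores for s in b]
flat_v = [v for b in values for v in b]
m = max(flat_s)
w = [math.exp(s - m) for s in flat_s]
reference = sum(wi * vi for wi, vi in zip(w, flat_v)) / sum(w)
assert math.isclose(online_attn_1d(scores, values, math.log2(math.e)), reference)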
tStP_r2t) + for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3): + cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) + cute.arch.fence_view_async_tmem_store() + # Notify mma warp that P is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + for i in cutlass.range_constexpr( + cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2]) + ): + cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) + cute.arch.fence_view_async_tmem_store() + # Notify mma warp that the 2nd half of P is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_2_offset + stage) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase + ) + softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) + # acc_scale = cute.arch.exp2(acc_scale_) + return mma_si_consumer_phase ^ 1, si_corr_producer_phase ^ 1, s0_s1_sequence_phase ^ 1 + + @cute.jit + def correction_loop( + self, + thr_mma_qk: cute.core.ThrMma, + thr_mma_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOtOs: tuple[cute.Tensor], + sScale: cute.Tensor, + mO: cute.Tensor, + mLSE: cute.Tensor, + sO: cute.Tensor, + learnable_sink: Optional[cute.Tensor], + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: cute.CopyAtom, + mbar_ptr: cute.Pointer, + softmax_scale_log2: Float32, + block_info: BlockInfo, + num_splits: Int32, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors] = None, + ): + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) + tStScale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1))) + tStScales = tuple( + cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStScale_layout) + for stage in range(self.q_stage) + ) + tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) + tmem_load_v_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), + self.qk_acc_dtype, + ) + thr_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]).get_slice(tidx) + + tStScales_t2r = [thr_tmem_load_vec.partition_S(tStScales[stage]) for stage in range(self.q_stage)] + tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScScale).shape + + # First iter: no correction is required + for stage in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + + softmax_corr_consumer_phase = Int32(0) + o_corr_consumer_phase = Int32(0) + corr_epi_producer_phase = Int32(1) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + + if const_expr(self.is_split_kv): + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] + else: + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] + gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + + # Default LSE to -inf for invalid split_idx tiles + stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage + + if 
const_expr(self.use_block_sparsity): + total_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1) + has_work = total_block_count > Int32(0) + else: + total_block_count = n_block_max - n_block_min + has_work = const_expr(not self.is_split_kv) or total_block_count > Int32(0) + + if has_work: + # Ignore first signal from softmax as no correction is required + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 0) + if const_expr(self.q_stage == 2): + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase + ) + softmax_corr_consumer_phase ^= 1 + + tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) + for i in cutlass.range(total_block_count - 1, unroll=1): + for stage in cutlass.range_constexpr(self.q_stage): + # wait for S0 / S1 + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) + # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) + # cute.arch.fence_view_async_tmem_load() + # scale = tSrScale_t2r[0] + scale = sScale[tidx + stage * self.m_block_size] + should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 + # should_rescale = True + # if tidx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) + # Don't need O_full anymore, since by the time softmax has signaled the correction + # warps, S_i must have been done, so O_i-1 must have been done as well. + # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) + if should_rescale: + self.correction_rescale( + thr_mma_pv, tOtOs[stage], tidx, scale + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + if const_expr(self.q_stage == 2): + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage) + ) + else: + cute.arch.mbarrier_arrive( + mbar_ptr + self.mbar_softmax_corr_empty_offset + stage + ) + softmax_corr_consumer_phase ^= 1 + # o_corr_consumer_phase ^= 1 + if const_expr(self.q_stage == 2): + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1) + # End of seqlen_corr_loop_steps + + # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without + # additional sync because the MMA in the top half must have been done. + # Similarly we can write to stage 1 of sO without additional sync. 
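+ # Note (illustrative sketch, not part of the original comment): the learnable sink
+ # acts like one extra per-head logit s that contributes to the softmax denominator
+ # but not to the output. With m = row_max * softmax_scale, the adjustment further
+ # below is roughly row_sum += exp(s - m), implemented as
+ # exp2(s * log2(e) - row_max * softmax_scale_log2), so the final normalizer becomes
+ # sum_j exp(scale * q.k_j - m) + exp(s - m).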
+ learnable_sink_val = [None] * self.q_stage + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + learnable_sink_val = [sink_val] * self.q_stage + else: # Each thread might have a different sink value due to different q_head + for stage in cutlass.range_constexpr(self.q_stage): + q_head_idx = ( + (self.q_stage * m_block + stage) * self.m_block_size + tidx + ) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx]) + for stage in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) + # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) + # cute.arch.fence_view_async_tmem_load() + # scale = tSrScale_t2r[0] + row_sum = sScale[tidx + stage * self.m_block_size] + if const_expr(mLSE is not None or learnable_sink is not None): + row_max = sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] + else: + row_max = None + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) + if const_expr(learnable_sink is not None): + LOG2_E = math.log2(math.e) + sink_val = learnable_sink_val[stage] + if const_expr(not self.is_split_kv) or split_idx == 0: + if row_max == -Float32.inf: + # It's possible to have an empty row with splitKV. + row_max = sink_val * (LOG2_E / softmax_scale_log2) + row_sum = Float32(1.0) + else: + row_sum += utils.exp2f( + sink_val * LOG2_E - row_max * softmax_scale_log2 + ) + acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) + scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase + ) + if const_expr(not self.use_correction_warps_for_epi): + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase + ) + self.correction_epilogue( + thr_mma_pv, + tOtOs[stage], + tidx, + stage, + m_block, + seqlen.seqlen_q, + scale, + sO[None, None, stage], + mO_cur, + gO, + gmem_tiled_copy_O, + ) + if const_expr(not self.use_correction_warps_for_epi): + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) + # Signal for the next work tile that O buffers in tmem are already read, so + # mma warp can write to them + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) + + o_corr_consumer_phase ^= 1 + softmax_corr_consumer_phase ^= 1 + corr_epi_producer_phase ^= 1 + else: + # WARNING: we need some code before the const_expr, see https://github.com/NVIDIA/cutlass/issues/2781 + if const_expr(self.use_correction_warps_for_epi): + gmem_tiled_copy_O_for_empty_tile = gmem_tiled_copy_O + else: + gmem_tiled_copy_O_for_empty_tile = None + if const_expr(self.use_block_sparsity): + ( + softmax_corr_consumer_phase, + o_corr_consumer_phase, + corr_epi_producer_phase, + ) = handle_block_sparse_empty_tile_correction_sm100( + tidx, + self.q_stage, + self.m_block_size, + self.qhead_per_kvhead, + self.pack_gqa, + self.is_split_kv, + learnable_sink, + mLSE, + seqlen, + m_block, + head_idx, + batch_idx, + split_idx, + sScale, + stats, + self.correction_epilogue, + thr_mma_pv, + tOtOs, + sO, + mbar_ptr, + self.mbar_softmax_corr_full_offset, + 
self.mbar_softmax_corr_empty_offset, + self.mbar_P_full_O_rescaled_offset, + self.mbar_P_full_2_offset, + self.mbar_corr_epi_full_offset, + self.mbar_corr_epi_empty_offset, + softmax_corr_consumer_phase, + o_corr_consumer_phase, + corr_epi_producer_phase, + softmax_scale_log2, + mO_cur, + gO, + gmem_tiled_copy_O_for_empty_tile, + ) + + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): + if const_expr(self.is_split_kv): + mLSE_cur = mLSE[None, head_idx, batch_idx, split_idx] + else: + mLSE_cur = mLSE[None, head_idx, batch_idx] + else: + offset = ( + seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + ) + if const_expr(self.is_split_kv): + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx, split_idx]) + else: + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) + for stage in cutlass.range_constexpr(self.q_stage): + gLSE = cute.local_tile( + mLSE_cur, (self.m_block_size,), (self.q_stage * m_block + stage,) + ) + row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] + # if tidx == 0 and stage <= 1: + # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) + LN2 = math.log(2.0) + lse = ( + (row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + seqlen_q = ( + seqlen.seqlen_q + if const_expr(not self.pack_gqa) + else seqlen.seqlen_q * self.qhead_per_kvhead + ) + if tidx < seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size: + # This actually just works with PackGQA too + gLSE[tidx] = lse + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def correction_rescale( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + tidx: Int32, + scale: Float32, + ): + """Rescale intermediate attention results based on softmax normalization factor. + + This method performs a crucial correction step in the attention computation pipeline. + When processing attention in blocks, the softmax normalization factors may change + as new blocks are processed. This method rescales previously computed partial + output values to account for updated normalization factors. + + The implementation uses efficient tensor memory operations to: + 1. Load existing partial attention output from tensor memory + 2. Apply the scaling factor to all elements + 3. 
Store the rescaled results back to tensor memory + """ + tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2])) + corr_tile_size = 16 # tuneable parameter + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tOtO_i = cute.composition(tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i = cute.composition(tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i).get_slice(tidx) + thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i).get_slice(tidx) + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i) + tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape + tOtO_r2t = thr_tmem_store.partition_D(tOtO_i) + + frg_count = self.head_dim_v_padded // corr_tile_size + tOrO_frg = cute.make_fragment((tOrO_t2r_shape, frg_count), self.pv_acc_dtype) + for i in cutlass.range_constexpr(frg_count): + tOrO_frg = cute.make_fragment(tOrO_t2r_shape, self.pv_acc_dtype) + tOtO_t2r_i = cute.make_tensor(tOtO_t2r.iterator + i * corr_tile_size, tOtO_t2r.layout) + cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg) + for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True): + tOrO_frg[j], tOrO_frg[j + 1] = utils.mul_packed_f32x2( + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), + ) + tOtO_r2t_i = cute.make_tensor(tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout) + cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i) + cute.arch.fence_view_async_tmem_store() + + @cute.jit + def correction_epilogue( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + tidx: Int32, + stage: Int32, + m_block: Int32, + seqlen_q: Int32, + scale: Float32, + sO: cute.Tensor, + mO_cur: Optional[cute.Tensor] = None, + gO: Optional[cute.Tensor] = None, + gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, + ): + """Apply final scaling and transformation to attention output before writing to global memory. + + This correction_epilogue function handles the final processing step for attention output values. + It applies a scaling factor to the accumulated attention results and prepares the + data for efficient transfer back to global memory. + + The method performs: + 1. Loading of accumulated attention results from tensor memory + 2. Application of the final output scaling factor + 3. Type conversion if necessary (typically from higher precision accumulator to output precision) + 4. Reorganization of data for optimal memory access patterns + 5. 
Preparation for efficient TMA store operations + + :param thr_mma: Thread MMA operation for the computation + :type thr_mma: cute.core.ThrMma + :param tOtO: Tensor containing accumulated attention output + :type tOtO: cute.Tensor + :param scale: Final scaling factor to apply to the output + :type scale: Float32 + :param sO: Shared memory tensor for the final output + :type sO: cute.Tensor + """ + + corr_tile_size = 32 * 8 // self.o_dtype.width + tOsO = thr_mma.partition_C(sO) + tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2])) + + tOtO_i = cute.logical_divide(tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i = cute.logical_divide(tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOsO_i = cute.logical_divide(tOsO, cute.make_layout((self.m_block_size, corr_tile_size))) + + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_copy_atom = sm100_utils_basic.get_tmem_load_op( + self.mma_tiler_pv, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=False, + ) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]).get_slice( + tidx + ) + thr_tmem_load = tiled_tmem_load.get_slice(tidx) + smem_copy_atom = sm100_utils_basic.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load + ) + tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load) + + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) + tOsO_s2r = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) + tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) + for i in cutlass.range_constexpr(self.head_dim_v_padded // corr_tile_size): + tOtO_t2r_i = tOtO_t2r[None, 0, 0, i] + tOsO_r2s_i = tOsO_s2r[None, 0, 0, i] + tOrO_frg = cute.make_fragment(tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype) + cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg) + for j in cutlass.range_constexpr(0, cute.size(tOrO_frg), 2): + tOrO_frg[j], tOrO_frg[j + 1] = utils.mul_packed_f32x2( + (tOrO_frg[j], tOrO_frg[j + 1]), + (scale, scale), + ) + tOrO_frg_cvt = cute.make_fragment(tOrO_frg.shape, self.o_dtype) + tOrO_frg_cvt.store(tOrO_frg.load().to(self.o_dtype)) + cute.copy(tiled_smem_store, tOrO_frg_cvt, tOsO_r2s_i) + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + if const_expr(self.use_correction_warps_for_epi): + assert(not self.use_tma_O) + assert(gmem_tiled_copy_O is not None) + cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO_cur.shape[1]) + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_v_padded, + self.check_hdim_v_oob, + self.qhead_per_kvhead, + ) + + # load acc O from smem to rmem for wider vectorization + tOrO = cute.make_fragment_like(tOsO, self.o_dtype) + cute.autovec_copy(tOsO, tOrO) + # copy acc O from rmem to gmem + if const_expr(not self.pack_gqa): + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if ( + t0OcO[0, rest_m, 0][0] + < seqlen_q + - (self.q_stage * m_block + stage) * self.m_block_size + - tOcO[0][0] + ): + 
cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None, self.q_stage * m_block + stage], + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, + ) + else: + pack_gqa.store_O( + mO_cur, + tOrO, + gmem_tiled_copy_O, + tidx, + self.q_stage * m_block + stage, + seqlen_q, + ) + + @cute.jit + def epilogue_s2g( + self, + mO: cute.Tensor, + sO: cute.Tensor, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + num_splits: int, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + epi_consumer_phase = Int32(0) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + if const_expr(self.is_split_kv): + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] + else: + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] + gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + if const_expr(self.use_tma_O): + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), sO, gO + ) + for stage in cutlass.range_constexpr(self.q_stage): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase + ) + # 2. copy O0 / O1 to gmem + store_O(src_idx=stage, dst_idx=self.q_stage * m_block + stage) + cute.arch.cp_async_bulk_commit_group() + for stage in cutlass.range_constexpr(self.q_stage): + # Ensure O0 / O1 buffer is ready to be released + if const_expr(self.q_stage == 2): + cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) + else: + cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + else: + tidx = cute.arch.thread_idx()[0] % ( + cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) + ) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_v_padded, + self.check_hdim_v_oob, + self.qhead_per_kvhead, + ) + for stage in cutlass.range_constexpr(self.q_stage): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase + ) + # 2. 
copy O0 / O1 to gmem + # load acc O from smem to rmem for wider vectorization + tOrO = cute.make_fragment_like(tOsO[None, None, None, 0], self.o_dtype) + cute.autovec_copy(tOsO[None, None, None, stage], tOrO) + # copy acc O from rmem to gmem + if const_expr(not self.pack_gqa): + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if ( + t0OcO[0, rest_m, 0][0] + < seqlen.seqlen_q + - (self.q_stage * m_block + stage) * self.m_block_size + - tOcO[0][0] + ): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None, self.q_stage * m_block + stage], + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, + ) + else: + pack_gqa.store_O( + mO_cur, + tOrO, + gmem_tiled_copy_O, + tidx, + self.q_stage * m_block + stage, + seqlen.seqlen_q, + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + + epi_consumer_phase ^= 1 + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + def load_Q( + self, + load_Q_fn: Callable, + mbar_full_ptr: cute.Pointer, + mbar_empty_ptr: cute.Pointer, + block: Int32, + stage: int, + phase: Int32, + ): + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, self.tma_copy_bytes["Q"]) + load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=mbar_full_ptr + stage) + + @cute.jit + def load_KV( + self, + tma_atom: Optional[cute.CopyAtom], + tXgX: Optional[cute.Tensor], + tXsX: Optional[cute.Tensor], + paged_kv_manager: Optional[PagedKVManager], + sX: cute.Tensor, + mbar_full_ptr: cute.Pointer, + mbar_empty_ptr: cute.Pointer, + block: Int32, + producer_state: cutlass.pipeline.PipelineState, + K_or_V: Literal["K", "V"], + page_idx: Optional[Int32] = None, + ): + assert K_or_V in ("K", "V") + stage, phase = producer_state.index, producer_state.phase + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) + if const_expr(K_or_V == "K" and self.uneven_kv_smem): + # Before this round, the smem location was occupied by V, which is smaller than + # K. So we need to wait for the stage after that (stage 1) to be empty as well. + if stage == 0: + cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase) + + if const_expr(self.use_tma_KV): + assert ( + tXgX is not None and + tXsX is not None and + tma_atom is not None + ) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_full_ptr + stage, self.tma_copy_bytes[K_or_V], + ) + tXsX_cur = tXsX[None, stage] + if const_expr(self.uneven_kv_smem): + # Since this is the producer_state, the phase starts at 1, so we have to invert it + tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1) + # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0 + tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx] + cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage) + else: + assert paged_kv_manager is not None + paged_kv_manager.load_KV(block, sX[None, None, None, stage], K_or_V) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_mbarrier_arrive_noinc(mbar_full_ptr + stage) + + @cute.jit + def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): + if const_expr(self.uneven_kv_smem): + # smem layout is [smem_large, smem_small, smem_large], and the current stride is + # (smem_large + smem_small) // 2. 
So for stage == 1, move right by offset if + # phase == 0, or left by offset if phase == 1. + offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase) + return cute.make_tensor(sX.iterator + offset, sX.layout) + else: + return sX + + def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): + load_kv_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + if self.use_tma_KV: + load_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len(self.load_warp_ids) + ) + return cutlass.pipeline.PipelineTmaUmma.create( + barrier_storage=load_kv_mbar_ptr, + num_stages=self.kv_stage, + producer_group=load_kv_producer_group, + consumer_group=load_kv_consumer_group, + tx_count=self.tma_copy_bytes["K"], + ) + else: + load_kv_producer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE + ) + return cutlass.pipeline.PipelineAsyncUmma.create( + num_stages=self.kv_stage, + producer_group=load_kv_producer_group, + consumer_group=load_kv_consumer_group, + barrier_storage=load_kv_mbar_ptr, + ) + + # @cute.jit + # def warp_scheduler_barrier_init(self): + # warp_group_idx = utils.canonical_warp_group_idx(sync=False) + # if warp_group_idx == 0: + # cute.arch.barrier_arrive( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * 128, + # ) + + # def warp_scheduler_barrier_sync(self): + # cute.arch.barrier( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False), + # number_of_threads=2 * 128 + # ) + + # def warp_scheduler_barrier_arrive(self): + # cur_wg = utils.canonical_warp_group_idx(sync=False) + # next_wg = 1 - cur_wg + # cute.arch.barrier_arrive( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, + # ) + + @cute.jit + def apply_score_mod( + self, + tSrS_t2r, + thr_tmem_load, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + n_block, + softmax, + seqlen: SeqlenInfoQK, + aux_tensors=None, + fastdiv_mods=(None, None), + head_divmod=None, + ): + """Apply score modification for SM100 (constant q_idx).""" + # Prepare index tensor with extra partition + cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + cS = cute.domain_offset((m_block * self.m_block_size, n_block * self.n_block_size), cS) + tScS = thr_mma_qk.partition_C(cS) + tScS_t2r = thr_tmem_load.partition_D(tScS) + + # Shared q_idx for all scores + q_idx_logical = tScS_t2r[0][0] + + # For Pack-GQA, compute the logical head index for this tile + if cutlass.const_expr(self.pack_gqa): + assert head_divmod is not None + # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead) + q_physical = q_idx_logical + q_idx_logical, head_offset = divmod(q_physical, head_divmod) + head_idx = head_idx * self.qhead_per_kvhead + head_offset + + if cutlass.const_expr(aux_tensors is not None): + seqlen_q_divmod, _ = fastdiv_mods + _, q_idx_logical = divmod(q_idx_logical, seqlen_q_divmod) + + apply_score_mod_inner( + tSrS_t2r, + tScS_t2r, + self.score_mod, + batch_idx, + head_idx, + softmax.softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info=seqlen, + constant_q_idx=q_idx_logical, + qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, + ) diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py new file mode 100644 
index 00000000000..c6a1c301904 --- /dev/null +++ b/flash_attn/cute/hopper_helpers.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025, Tri Dao. +from typing import Type, Union, Optional +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Float32, Boolean, const_expr +from cutlass.cute.nvgpu import warpgroup +from cutlass.cutlass_dsl import Numeric, dsl_user_op +from cutlass.utils import LayoutEnum +import cutlass.utils.hopper_helpers as sm90_utils_og + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: cutlass.Constexpr[bool] = False, + wg_wait: cutlass.Constexpr[int] = 0, + # A_in_regs: cutlass.Constexpr[bool] = False, + swap_AB: cutlass.Constexpr[bool] = False, +) -> None: + if const_expr(swap_AB): + gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False) + else: + warpgroup.fence() + # We make a new mma_atom since we'll be modifying its attribute (accumulate). + # Otherwise the compiler complains "operand #0 does not dominate this use" + mma_atom = cute.make_mma_atom(tiled_mma.op) + mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + mma_atom.set(warpgroup.Field.ACCUMULATE, True) + warpgroup.commit_group() + if const_expr(wg_wait >= 0): + warpgroup.wait_group(wg_wait) + + +def gemm_zero_init( + tiled_mma: cute.TiledMma, + shape: cute.Shape, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + wg_wait: int = -1, + swap_AB: bool = False, +) -> cute.Tensor: + if const_expr(swap_AB): + return gemm_zero_init( + tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False + ) + else: + acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32) + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait) + return acc + + +def gemm_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: Boolean, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + wg_wait: int = -1, + swap_AB: bool = False, +) -> None: + if const_expr(swap_AB): + gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False) + else: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait) + + +@dsl_user_op +def make_smem_layout( + dtype: Type[Numeric], + layout: LayoutEnum, + shape: cute.Shape, + stage: Optional[int] = None, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + major_mode_size = shape[1] if layout.is_n_major_c() else shape[0] + smem_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size), + dtype, + ) + order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2) + smem_layout_staged = cute.tile_to_shape( + smem_layout_atom, + cute.append(shape, stage) if const_expr(stage is not None) else shape, + order=order if const_expr(stage is not None) else order[:2], + ) + return smem_layout_staged diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py new file mode 100644 
index 00000000000..8d240698ce9 --- /dev/null +++ b/flash_attn/cute/interface.py @@ -0,0 +1,1758 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. + +# Supported features: +# - BF16 & FP16 dtype +# - noncausal & causal attention +# - MHA, GQA, MQA +# - hdim 64, 96, 128. +# - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape) +# - varlen +# - sliding window +# - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow) + +# Features not supported yet: +# - split (i.e. FlashDecoding) +# - tuned block sizes +# - paged KV +# - append KV to existing KV cache +# - FP8 +# - bwd pass optimized for Hopper/Blackwell + +import math +from functools import lru_cache +from typing import Optional, Tuple, Callable + +import torch + + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute + +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import to_cute_tensor +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 +from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess +from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 +from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 +from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine + +from flash_attn.cute.block_sparsity import ( + BlockSparseTensorsTorch, + to_cute_block_sparse_tensors, + normalize_block_sparse_tensors, + get_block_sparse_expected_shapes, + get_block_sparse_expected_shapes_bwd, + get_block_sparse_broadcast_pattern, +) + +@lru_cache(maxsize=None) +def _get_device_capability(): + """Cached device capability check.""" + return torch.cuda.get_device_capability()[0] + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device): + assert t.shape == expected_shape, f"{name} shape {t.shape} != expected {expected_shape}" + assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}" + assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}" + assert t.is_cuda, f"{name} must be on CUDA" + + +torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, +} + + +def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): + # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. + if num_n_blocks <= 4: + return 1 + + # NOTE: We should revisit this heuristic after persistence is supported for split KV. + # Sometimes, it's ideal to over-schedule splits for better efficiency. 
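+ # Illustrative example (numbers are not from the kernel): with num_SMs=132,
+ # total_mblocks=4, num_n_blocks=16 and max_splits=128 this returns
+ # min(132 // 4, 128, 16) = 16; more than 132 // 4 = 33 splits would oversubscribe
+ # the SMs, and a split can never be finer than one KV block.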
+ return min(num_SMs // total_mblocks, max_splits, num_n_blocks) + + +def _flash_attn_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + softcap: Optional[float] = None, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + learnable_sink: Optional[torch.Tensor] = None, + # m_block_size: int = 128, + # n_block_size: int = 64, + # num_threads: int = 128, + m_block_size: int = 128, + n_block_size: int = 128, + num_threads: int = 384, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + _compute_capability: Optional[int] = None, + score_mod: Optional[Callable] = None, + mask_mod: Optional[Callable] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, + return_lse: bool = False, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + aux_tensors: Optional[list[torch.Tensor]] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for FlashAttention. + + Args: + ... + score_mod: A callable that takes the attention scores and applies a modification. + mask_mod: A callable that takes token position information and selectively masks + block_sparse_tensors: A tuple of tensors used for block sparsity. + return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate + out: Optional pre-allocated output tensor. If None, will be allocated internally. + lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. + aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. 
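+        Example (illustrative sketch): ``score_mod`` is assumed here to follow the
+        FlexAttention-style convention ``score_mod(score, batch_idx, head_idx, q_idx, kv_idx)``;
+        see ``utils.create_softcap_scoremod`` for the exact convention used internally.
+
+            def rel_bias(score, b, h, q_idx, kv_idx):
+                # toy additive position bias, purely for illustration
+                return score - 0.1 * (q_idx - kv_idx)
+
+            out, lse = _flash_attn_fwd(q, k, v, score_mod=rel_bias, return_lse=True)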
+ """ + q, k, v = [maybe_contiguous(t) for t in (q, k, v)] + num_head, head_dim = q.shape[-2:] + if cu_seqlens_q is None: + batch_size, seqlen_q = q.shape[:2] + total_q = batch_size * seqlen_q + else: + batch_size = cu_seqlens_q.shape[0] - 1 + seqlen_q = None + total_q = q.shape[0] + if page_table is not None: + assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k" + assert page_table.dtype == torch.int32, "page_table must be int32" + assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension" + max_num_pages_per_seq = page_table.shape[1] + assert page_table.shape == (batch_size, max_num_pages_per_seq) + num_pages, page_size = k.shape[:2] + seqlen_k = num_pages * page_size + else: + num_pages, page_size = None, None + seqlen_k = k.shape[-3] + num_head_kv = k.shape[-2] + head_dim_v = v.shape[-1] + if cu_seqlens_k is None: + if page_table is None: + assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) + assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) + else: + assert k.shape == (num_pages, page_size, num_head_kv, head_dim) + assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v) + else: + assert k.shape == (seqlen_k, num_head_kv, head_dim) + assert v.shape == (seqlen_k, num_head_kv, head_dim_v) + assert cu_seqlens_k.shape == (batch_size + 1,), ( + "cu_seqlens_k must have shape (batch_size + 1,)" + ) + + if cu_seqlens_q is not None: + assert cu_seqlens_q.shape == (batch_size + 1,), ( + "cu_seqlens_q must have shape (batch_size + 1,)" + ) + assert seqused_q is None or seqused_q.shape == (batch_size,), ( + "seqused_q must have shape (batch_size,)" + ) + assert seqused_k is None or seqused_k.shape == (batch_size,), ( + "seqused_k must have shape (batch_size,)" + ) + assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" + assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" + for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: + if t is not None: + assert t.dtype == torch.int32, ( + "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" + ) + assert t.stride(0) == 1, ( + "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" + ) + if learnable_sink is not None: + assert learnable_sink.shape == (num_head,) + assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" + + assert all( + t is None or t.is_cuda + for t in ( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + page_table, + learnable_sink, + ) + ), "inputs must be on CUDA device" + assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" + assert head_dim <= 256, "head_dim must be less than or equal to 256" + alignment = 16 // q.element_size() + assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" + assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(head_dim) + if softcap == 0.0: + softcap = None + qhead_per_kvhead = num_head // num_head_kv + if pack_gqa is None: + pack_gqa = qhead_per_kvhead > 1 + + out_torch_dtype = q.dtype + device = q.device + q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) + lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) + requires_grad = q.requires_grad or k.requires_grad or v.requires_grad + + if out is None: + out = torch.empty( + *q_batch_seqlen_shape, num_head, 
head_dim_v, dtype=out_torch_dtype, device=device + ) + else: + _validate_tensor(out, "out", (*q_batch_seqlen_shape, num_head, head_dim_v), out_torch_dtype, device) + + if lse is None: + lse = ( + torch.empty(lse_shape, dtype=torch.float32, device=device) + if requires_grad or return_lse + else None + ) + elif lse is not None: + _validate_tensor(lse, "lse", lse_shape, torch.float32, device) + + dtype = torch2cute_dtype_map[q.dtype] + compute_capability = ( + _get_device_capability() + if _compute_capability is None + else _compute_capability + ) + + assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x" + + use_block_sparsity = block_sparse_tensors is not None + + if mask_mod is None: + if causal: + window_size_right = 0 + local = window_size_left is not None or window_size_right is not None + if window_size_left is not None or window_size_right is not None: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + window_size_right = None + else: + causal, local = False, True + else: + causal, local = False, False + + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + if compute_capability == 9: # TODO: tune block size according to hdim. + if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity: + n_block_size = 192 + + if compute_capability in [10, 11]: + if ( + pack_gqa + and (128 % qhead_per_kvhead != 0) + ): + pack_gqa = False + # TODO: fix GQA + SplitKV + non-varlen + if pack_gqa and num_splits != 1 and cu_seqlens_q is None: + pack_gqa = False + + if max_seqlen_q is None: + max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q + if max_seqlen_k is None: + max_seqlen_k = seqlen_k + seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead + if compute_capability == 10: + q_stage = 2 if seqlen_q_packgqa > m_block_size else 1 + else: + q_stage = 1 + + if num_splits < 1: + m_block_size_effective = q_stage * m_block_size + seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, window_size_right + window_size_left + 1 + m_block_size)) + num_n_blocks = (seqlen_k_loaded + n_block_size - 1) // n_block_size + num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective + total_mblocks = batch_size * num_head_kv * num_m_blocks + num_splits = num_splits_heuristic( + total_mblocks, + torch.cuda.get_device_properties(device).multi_processor_count, + num_n_blocks, + 128, + ) + + is_split_kv = num_splits > 1 + if is_split_kv: + out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device) + lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device) + + # hash score and mask mods for compile cache + score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False + mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False + + if softcap is not None: + assert score_mod is None, "softcap and score_mod cannot be used together" + score_mod = utils.create_softcap_scoremod(softcap) + + is_varlen = ( + cu_seqlens_q is not None + or cu_seqlens_k is not None + or seqused_q is not None + or seqused_k is not None + ) + + if mask_mod is not None: + if is_varlen: + raise NotImplementedError( + "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." 
+ ) + + if use_block_sparsity: + if is_varlen: + raise NotImplementedError( + "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR." + ) + # NB: pack_gqa requires block sparse head dim == 1 (broadcasted) + if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1: + pack_gqa = False + if is_split_kv: + raise NotImplementedError( + "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." + ) + + # See get_broadcast_dims for why this is needed in compile key + block_sparse_broadcast_pattern = None + normalized_block_sparse_tensors = None + if block_sparse_tensors is not None: + if seqlen_q is None: + raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).") + expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes( + batch_size, num_head, seqlen_q, seqlen_k, + m_block_size, n_block_size, q_stage, + ) + normalized_block_sparse_tensors = normalize_block_sparse_tensors( + block_sparse_tensors, + expected_count_shape=expected_count_shape, + expected_index_shape=expected_index_shape, + ) + block_sparse_broadcast_pattern = get_block_sparse_broadcast_pattern( + normalized_block_sparse_tensors + ) + + compile_key = ( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + score_mod_hash, + mask_mod_hash, + use_block_sparsity, + block_sparse_broadcast_pattern, + len(aux_tensors) if aux_tensors is not None else 0, + lse is None, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, + page_table is not None, + window_size_left is not None, + window_size_right is not None, + learnable_sink is not None, + m_block_size, + n_block_size, + q_stage, + num_threads, + is_split_kv, + pack_gqa, + compute_capability, + page_size not in [None, 128], # paged KV non-TMA + ) + if compile_key not in _flash_attn_fwd.compile_cache: + ( + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + learnable_sink_tensor, + ) = [ + to_cute_tensor(t, assumed_align=4, leading_dim=0) + if t is not None + else None + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) + ] + page_table_tensor = ( + to_cute_tensor(page_table, assumed_align=4, leading_dim=1) + if page_table is not None + else None + ) + q_tensor, k_tensor, v_tensor, o_tensor = [ + to_cute_tensor(t) for t in (q, k, v, out if not is_split_kv else out_partial) + ] + if is_split_kv: + lse_tensor = to_cute_tensor(lse_partial, assumed_align=4) + elif lse is not None: + lse_tensor = to_cute_tensor(lse, assumed_align=4) + else: + lse_tensor = None + + sparse_tensors = None + if normalized_block_sparse_tensors is not None: + sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) + + cute_aux_tensors = None + if aux_tensors is not None: + cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors] + + if compute_capability == 9: + assert page_table is None, "paged KV not supported on SM 9.0" + assert not is_split_kv, "SplitKV not supported on SM 9.0" + # fa_fwd = FlashAttentionForwardSm80( + fa_fwd = FlashAttentionForwardSm90( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + is_causal=causal, + is_local=local, + pack_gqa=pack_gqa, + tile_m=m_block_size, + tile_n=n_block_size, + # num_stages=1, + num_stages=2, + num_threads=num_threads, + Q_in_regs=False, + intra_wg_overlap=True, + mma_pv_is_rs=True, + mask_mod=mask_mod, + score_mod=score_mod, + 
has_aux_tensors=aux_tensors is not None, + ) + elif compute_capability in [10, 11]: + fa_fwd = FlashAttentionForwardSm100( + head_dim, + head_dim_v, + qhead_per_kvhead=qhead_per_kvhead, + is_causal=causal, + is_local=local, + is_split_kv=is_split_kv, + pack_gqa=pack_gqa, + m_block_size=m_block_size, + n_block_size=n_block_size, + q_stage=q_stage, + is_persistent=not causal + and not local + and cu_seqlens_q is None + and seqused_q is None + and not is_split_kv, + score_mod=score_mod, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + paged_kv_non_tma=page_size not in [None, 128], + is_varlen_q=cu_seqlens_q is not None + or seqused_q is not None, + ) + else: + raise ValueError( + f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x, 11.x" + ) + # TODO: check @can_implement + _flash_attn_fwd.compile_cache[compile_key] = cute.compile( + fa_fwd, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + current_stream, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + page_table_tensor, + window_size_left, + window_size_right, + learnable_sink_tensor, + sparse_tensors, + cute_aux_tensors, + options="--enable-tvm-ffi", + ) + + _flash_attn_fwd.compile_cache[compile_key]( + q, + k, + v, + out if not is_split_kv else out_partial, + lse_partial if is_split_kv else lse, + softmax_scale, + current_stream, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + page_table, + window_size_left, + window_size_right, + learnable_sink, + normalized_block_sparse_tensors, + aux_tensors, + ) + if is_split_kv: + _flash_attn_fwd_combine( + out_partial, + lse_partial.transpose(-1, -2), + out, + lse.transpose(-1, -2) if lse is not None else None, + cu_seqlens_q, + seqused_q, + ) + return out, lse + + +_flash_attn_fwd.compile_cache = {} + + +def _flash_attn_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + dout: torch.Tensor, + lse: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + softcap: float = 0.0, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + m_block_size: int = 64, + n_block_size: int = 128, + num_threads: int = 256, + pack_gqa: bool = False, + num_stages_Q: int = 2, + num_stages_dO: int = 2, + SdP_swapAB: bool = False, + dKV_swapAB: bool = False, + dQ_swapAB: bool = False, + AtomLayoutMSdP: int = 2, + AtomLayoutNdKV: int = 2, + AtomLayoutMdQ: int = 2, + V_in_regs: bool = False, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + deterministic: bool = False, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + score_mod: Optional[Callable] = None, + score_mod_bwd: Optional[Callable] = None, + mask_mod: Optional[Callable] = None, + aux_tensors: Optional[list[torch.Tensor]] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + compute_capability = _get_device_capability() + assert compute_capability in [9, 10, 11], "Unsupported compute capability. 
Supported: 9.x, 10.x, 11.x" + + if compute_capability == 9: + m_block_size = 80 if not causal else 64 + n_block_size = 128 + num_stages_Q = 2 + num_stages_dO = 2 + num_stages_PdS = 2 + SdP_swapAB = True + dKV_swapAB = False + dQ_swapAB = not causal + AtomLayoutMSdP = 1 + AtomLayoutNdKV = 2 + AtomLayoutMdQ = 1 + cluster_size = 1 + assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x" + else: + m_block_size = 128 + n_block_size = 128 + dQ_swapAB = False + dKV_swapAB = False + AtomLayoutMdQ = 1 + AtomLayoutNdKV = 1 + # TODO: support cluster size 2 + cluster_size = 1 + q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ + maybe_contiguous(t) + for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + ] + num_head, head_dim = q.shape[-2:] + if cu_seqlens_q is None: + batch_size, seqlen_q = q.shape[:2] + total_q = batch_size * seqlen_q + else: + batch_size = cu_seqlens_q.shape[0] - 1 + total_q = q.shape[0] + seqlen_q = max_seqlen_q if max_seqlen_q is not None else total_q + + if cu_seqlens_k is None: + batch_size, seqlen_k = k.shape[:2] + total_k = batch_size * seqlen_k + else: + batch_size = cu_seqlens_k.shape[0] - 1 + total_k = k.shape[0] + seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k + + num_head_kv = k.shape[-2] + head_dim_v = v.shape[-1] + + if causal: + window_size_right = 0 + local = window_size_left is not None or window_size_right is not None + if local: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + window_size_right = None + else: + causal, local = False, True + + use_block_sparsity = block_sparse_tensors is not None + + # SM90 block-sparse backward: tile_m=64 is the GCD between a m_block_size that fits, + # the base block_m of 128 from forward, and block-sparse size for subtiling. 
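+    # Concretely (illustrative arithmetic): with m_block_size = 64 and subtile_factor = 2
+    # as set below, sparse_block_size_q = 2 * 64 = 128, so one Q-direction sparse block
+    # from the BlockMask covers exactly two backward m-tiles.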
+ if compute_capability == 9 and use_block_sparsity: + m_block_size = 64 + # dQ_swapAB tuning: use False when m_block_size=64 (same as causal case) + dQ_swapAB = False + + # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2 + subtile_factor = 2 + sparse_block_size_q = subtile_factor * m_block_size + + seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size + seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size + + if cu_seqlens_k is None: + assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) + assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) + else: + assert k.shape == (total_k, num_head_kv, head_dim) + assert v.shape == (total_k, num_head_kv, head_dim_v) + assert cu_seqlens_k.shape == (batch_size + 1,), ( + "cu_seqlens_k must have shape (batch_size + 1,)" + ) + + if cu_seqlens_q is not None: + assert cu_seqlens_q.shape == (batch_size + 1,), ( + "cu_seqlens_q must have shape (batch_size + 1,)" + ) + + assert out.shape == (total_q, num_head, head_dim_v) + assert dout.shape == (total_q, num_head, head_dim_v) + assert lse.shape == (num_head, total_q), "lse must have shape (num_head, total_q)" + else: + assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v) + assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v) + assert lse.shape == (batch_size, num_head, seqlen_q), ( + "lse must have shape (batch_size, num_head, seqlen_q)" + ) + + assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" + assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, ( + "inputs must have the same dtype" + ) + for t in [cu_seqlens_q, cu_seqlens_k]: + if t is not None: + assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32" + assert lse.dtype == torch.float32, "lse must be float32" + assert all( + t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k) + ), "inputs must be on CUDA device" + assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" + assert head_dim <= 256, "head_dim must be less than or equal to 256" + alignment = 16 // q.element_size() + assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" + assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(head_dim) + qhead_per_kvhead = num_head // num_head_kv + if pack_gqa is None: + pack_gqa = qhead_per_kvhead > 1 + # pack_gqa backward not yet supported in bwd + pack_gqa = False + if compute_capability not in [10, 11]: + assert deterministic is False, "bwd deterministic only supported for sm100/sm110 for now" + + if score_mod is not None: + assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided" + assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)" + assert cu_seqlens_q is None and cu_seqlens_k is None, ( + "varlen + score_mod not supported in bwd yet" + ) + + device = q.device + out_torch_dtype = q.dtype + + if dq is None: + dq = torch.empty_like(q) + else: + _validate_tensor(dq, "dq", q.shape, out_torch_dtype, device) + + if dk is None: + dk = torch.empty_like(k) + else: + _validate_tensor(dk, "dk", k.shape, out_torch_dtype, device) + + if dv is None: + dv = torch.empty_like(v) + else: + _validate_tensor(dv, "dv", v.shape, out_torch_dtype, device) + + head_dim_rounded = (head_dim + 32 - 1) // 32 * 32 + + if cu_seqlens_q 
is None: + dq_accum = torch.empty( + batch_size, + num_head, + seqlen_q_rounded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) + dpsum = torch.empty( + batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device + ) + lse_log2 = torch.empty( + batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device + ) + else: + total_q_rounded_padded = ( + (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size + ) + dq_accum = torch.empty( + num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device + ) + dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) + lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device) + + dKV_postprocess = qhead_per_kvhead > 1 + if dKV_postprocess: + head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 + if cu_seqlens_k is None: + num_n_blocks = seqlen_k_rounded // n_block_size + if cluster_size == 2 and num_n_blocks % cluster_size != 0: + seqlen_k_rounded = seqlen_k_rounded + n_block_size + dk_accum = torch.zeros( + batch_size, + num_head_kv, + seqlen_k_rounded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) + dv_accum = torch.zeros( + batch_size, + num_head_kv, + seqlen_k_rounded * head_dim_v_rounded, + dtype=torch.float32, + device=device, + ) + else: + total_k_rounded_padded = ( + (total_k + cu_seqlens_k.shape[0] * n_block_size - 1) // n_block_size * n_block_size + ) + num_n_blocks = total_k_rounded_padded // n_block_size + if cluster_size == 2 and num_n_blocks % cluster_size != 0: + total_k_rounded_padded = total_k_rounded_padded + n_block_size + dk_accum = torch.zeros( + num_head_kv, + total_k_rounded_padded * head_dim_rounded, + dtype=torch.float32, + device=device, + ) + dv_accum = torch.zeros( + num_head_kv, + total_k_rounded_padded * head_dim_v_rounded, + dtype=torch.float32, + device=device, + ) + + dtype = torch2cute_dtype_map[q.dtype] + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + if deterministic: + dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device="cuda") + else: + dQ_semaphore = None + + if deterministic and qhead_per_kvhead > 1: + dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") + dV_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device="cuda") + else: + dK_semaphore = None + dV_semaphore = None + + # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum. 
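+    # Reference semantics (illustrative eager-mode sketch of what the preprocess produces):
+    #     dpsum    ~ (out.float() * dout.float()).sum(dim=-1)   # per-row D term of the FA backward
+    #     lse_log2 ~ lse * math.log2(math.e)                    # lets the main kernel work in exp2
+    #     dq_accum = 0                                          # fp32 accumulator for dQ, zero-filled
+    # (padded out to seqlen_q_rounded / head_dim_rounded as allocated above).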
+ compile_key_pre = ( + compute_capability, + dtype, + head_dim_v, + m_block_size, + num_threads, + cu_seqlens_q is None, + seqused_q is None, + ) + if compile_key_pre not in _flash_attn_bwd.compile_cache_pre: + o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)] + dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ + to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2) + ] + lse_tensor = to_cute_tensor(lse, assumed_align=4) + cu_seqlens_q_tensor, seqused_q_tensor = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in (cu_seqlens_q, seqused_q) + ] + arch = compute_capability * 10 + fa_bwd_pre = FlashAttentionBackwardPreprocess( + dtype, + head_dim_v, + arch, + m_block_size, + num_threads=num_threads, + ) + # TODO: check @can_implement + _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile( + fa_bwd_pre, + o_tensor, + do_tensor, + dpsum_tensor, + lse_tensor, + lse_log2_tensor, + dq_accum_tensor, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, + options="--enable-tvm-ffi", + ) + _flash_attn_bwd.compile_cache_pre[compile_key_pre]( + out, + dout, + dpsum, + lse, + lse_log2, + dq_accum, + cu_seqlens_q, + seqused_q, + current_stream, + ) + + # NB num_threads application for 3 kernels + # There are pre, main, post processing kernels, currenlty num_threads is only actually + # used for the pre proc, and then we hard code to 384 for the main and post proc, and we do + # before cache key gen + num_threads = 384 + + # Backward kernel: compute dk, dv, dq_accum. + score_mod_hash = utils.hash_callable(score_mod) if score_mod else False + score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False + mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod else False + num_aux_tensors = len(aux_tensors) if aux_tensors else 0 + cute_aux_tensors = None + if aux_tensors is not None: + cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors] + + block_sparse_broadcast_pattern = None + normalized_block_sparse_tensors = None + if block_sparse_tensors is not None: + expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd( + batch_size, num_head, seqlen_q, seqlen_k, + m_block_size, n_block_size, subtile_factor, + ) + normalized_block_sparse_tensors = normalize_block_sparse_tensors( + block_sparse_tensors, + expected_count_shape=expected_count_shape, + expected_index_shape=expected_index_shape, + context="_flash_attn_bwd", + hint=lambda: ( + f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). " + f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) " + f"(sparse_block_size_q={sparse_block_size_q})." 
+ ), + ) + block_sparse_broadcast_pattern = get_block_sparse_broadcast_pattern( + normalized_block_sparse_tensors + ) + + if compute_capability == 9: + compile_key = ( + compute_capability, + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + softcap != 0.0, + m_block_size, + n_block_size, + num_threads, + pack_gqa, + num_stages_Q, + num_stages_dO, + SdP_swapAB, + dKV_swapAB, + dQ_swapAB, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, + V_in_regs, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, + score_mod_hash, + score_mod_bwd_hash, + mask_mod_hash, + num_aux_tensors, + use_block_sparsity, + block_sparse_broadcast_pattern, + ) + else: + compile_key = ( + compute_capability, + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + window_size_left is not None, + window_size_right is not None, + softcap != 0.0, + m_block_size, + n_block_size, + num_threads, + pack_gqa, + cluster_size, + deterministic, + score_mod_hash, + score_mod_bwd_hash, + mask_mod_hash, + num_aux_tensors, + use_block_sparsity, + block_sparse_broadcast_pattern, + cu_seqlens_q is None, + cu_seqlens_k is None, + seqused_q is None, + seqused_k is None, + ) + if compile_key not in _flash_attn_bwd.compile_cache: + q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ + to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv) + ] + dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ + to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2) + ] + if dKV_postprocess: + dk_accum_tensor, dv_accum_tensor = [ + to_cute_tensor(t) for t in (dk_accum, dv_accum) + ] + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + ] + dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [ + utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order()) + if t is not None else None + for t in (dQ_semaphore, dK_semaphore, dV_semaphore) + ] + fa_bwd_sm80 = FlashAttentionBackwardSm80( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + m_block_size, + n_block_size, + num_stages_Q, + num_stages_dO, + num_threads, + pack_gqa, + causal, + SdP_swapAB, + dKV_swapAB, + dQ_swapAB, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, + V_in_regs=V_in_regs, + ) + if compute_capability == 9: + fa_bwd_obj = FlashAttentionBackwardSm90( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + causal, + m_block_size, + n_block_size, + num_stages_Q, + num_stages_dO, + num_stages_PdS, + SdP_swapAB, + dKV_swapAB, + dQ_swapAB, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, + num_threads, + V_in_regs=V_in_regs, + score_mod=score_mod, + score_mod_bwd=score_mod_bwd, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + subtile_factor=subtile_factor, + ) + else: + fa_bwd_obj = FlashAttentionBackwardSm100( + head_dim, + head_dim_v, + is_causal=causal, + is_local=local, + qhead_per_kvhead=qhead_per_kvhead, + # tile_m=m_block_size, + # tile_n=n_block_size, + cluster_size=cluster_size, + # cluster_size=1, + deterministic=deterministic, + score_mod=score_mod, + score_mod_bwd=score_mod_bwd, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + subtile_factor=subtile_factor, + ) + + # Block sparse tensors for backward use Q-direction indexing (transposed from forward). 
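+        # Roughly speaking (illustrative reading of the layout): q_mask_cnt/q_mask_idx list, for
+        # each KV block, which Q blocks have to be visited, whereas the forward tensors list KV
+        # blocks per Q block -- hence "transposed from forward".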
+ # sparse_block_size_q = subtile_factor * tile_m matches BlockMask granularity. + sparse_tensors_compile = None + if normalized_block_sparse_tensors is not None: + sparse_tensors_compile = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) + + # TODO: check @can_implement + _flash_attn_bwd.compile_cache[compile_key] = cute.compile( + fa_bwd_obj, + q_tensor, + k_tensor, + v_tensor, + do_tensor, + lse_log2_tensor, + dpsum_tensor, + dq_accum_tensor, + dk_tensor if not dKV_postprocess else dk_accum_tensor, + dv_tensor if not dKV_postprocess else dv_accum_tensor, + softmax_scale, + current_stream, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + None, # softcap - not yet supported in backward + window_size_left, + window_size_right, + dQ_semaphore_tensor, + dK_semaphore_tensor, + dV_semaphore_tensor, + cute_aux_tensors, + sparse_tensors_compile, + options="--enable-tvm-ffi", + ) + _flash_attn_bwd.compile_cache[compile_key]( + q, + k, + v, + dout, + lse_log2, + dpsum, + dq_accum, + dk if not dKV_postprocess else dk_accum, + dv if not dKV_postprocess else dv_accum, + softmax_scale, + current_stream, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + None, # softcap - not yet supported in backward + window_size_left, + window_size_right, + dQ_semaphore, + dK_semaphore, + dV_semaphore, + aux_tensors, + normalized_block_sparse_tensors, + ) + + num_threads = 256 if compute_capability == 9 else 128 + arch = compute_capability * 10 + # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16 + compile_key_post = ( + compute_capability, + dtype, + head_dim, + m_block_size, + num_threads, + AtomLayoutMdQ, + dQ_swapAB, + cu_seqlens_q is None, + seqused_q is None, + ) + if compile_key_post not in _flash_attn_bwd.compile_cache_post: + dq_accum_tensor = to_cute_tensor(dq_accum) + dq_tensor = to_cute_tensor(dq) + cu_seqlens_q_tensor, seqused_q_tensor = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in (cu_seqlens_q, seqused_q) + ] + fa_bwd_post = FlashAttentionBackwardPostprocess( + dtype, head_dim, arch, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB + ) + # TODO: check @can_implement + _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( + fa_bwd_post, + dq_accum_tensor, + dq_tensor, + softmax_scale, + cu_seqlens_q_tensor, + seqused_q_tensor, + current_stream, + options="--enable-tvm-ffi", + ) + _flash_attn_bwd.compile_cache_post[compile_key_post]( + dq_accum, + dq, + softmax_scale, + cu_seqlens_q, + seqused_q, + current_stream, + ) + + if dKV_postprocess: + # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16 + compile_key_post = ( + compute_capability, + dtype, + head_dim, + n_block_size, + num_threads, + AtomLayoutNdKV, + dKV_swapAB, + cu_seqlens_k is None, + seqused_k is None, + ) + if compile_key_post not in _flash_attn_bwd.compile_cache_post: + dk_accum_tensor = to_cute_tensor(dk_accum) + dk_tensor = to_cute_tensor(dk) + cu_seqlens_k_tensor, seqused_k_tensor = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in (cu_seqlens_k, seqused_k) + ] + arch = compute_capability * 10 + fa_bwd_post = FlashAttentionBackwardPostprocess( + dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB + ) + # TODO: check @can_implement + _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( + fa_bwd_post, + dk_accum_tensor, + dk_tensor, + softmax_scale, + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, + 
options="--enable-tvm-ffi", + ) + _flash_attn_bwd.compile_cache_post[compile_key_post]( + dk_accum, + dk, + softmax_scale, + cu_seqlens_k, + seqused_k, + current_stream, + ) + compile_key_post = ( + compute_capability, + dtype, + head_dim_v, + n_block_size, + num_threads, + AtomLayoutNdKV, + dKV_swapAB, + cu_seqlens_k is None, + seqused_k is None, + ) + if compile_key_post not in _flash_attn_bwd.compile_cache_post: + dv_accum_tensor = to_cute_tensor(dv_accum) + dv_tensor = to_cute_tensor(dv) + cu_seqlens_k_tensor, seqused_k_tensor = [ + to_cute_tensor(t, assumed_align=4) if t is not None else None + for t in (cu_seqlens_k, seqused_k) + ] + arch = compute_capability * 10 + fa_bwd_post = FlashAttentionBackwardPostprocess( + dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB + ) + # TODO: check @can_implement + _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( + fa_bwd_post, + dv_accum_tensor, + dv_tensor, + cutlass.Float32(1.0), + cu_seqlens_k_tensor, + seqused_k_tensor, + current_stream, + options="--enable-tvm-ffi", + ) + _flash_attn_bwd.compile_cache_post[compile_key_post]( + dv_accum, + dv, + 1.0, + cu_seqlens_k, + seqused_k, + current_stream, + ) + + return dq, dk, dv + + +_flash_attn_bwd.compile_cache_pre = {} +_flash_attn_bwd.compile_cache = {} +_flash_attn_bwd.compile_cache_post = {} + + +class FlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), + learnable_sink: Optional[torch.Tensor] = None, + softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, + mask_mod: Optional[Callable] = None, + full_block_cnt: Optional[torch.Tensor] = None, + full_block_idx: Optional[torch.Tensor] = None, + mask_block_cnt: Optional[torch.Tensor] = None, + mask_block_idx: Optional[torch.Tensor] = None, + ): + # Only create block sparse tensors if at least one block sparse parameter is provided + block_sparse_tensors = None + if any(t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]): + block_sparse_tensors = BlockSparseTensorsTorch( + full_block_cnt=full_block_cnt, + full_block_idx=full_block_idx, + mask_block_cnt=mask_block_cnt, + mask_block_idx=mask_block_idx, + ) + out, lse = _flash_attn_fwd( + q, + k, + v, + softmax_scale=softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + learnable_sink=learnable_sink, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + mask_mod=mask_mod, + block_sparse_tensors=block_sparse_tensors + ) + ctx.save_for_backward(q, k, v, out, lse) + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.deterministic = deterministic + return out, lse + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, lse = ctx.saved_tensors + dq, dk, dv = _flash_attn_bwd( + q, + k, + v, + out, + dout, + lse, + ctx.softmax_scale, + ctx.causal, + ctx.softcap, + window_size_left=ctx.window_size[0], + window_size_right=ctx.window_size[1], + deterministic=ctx.deterministic, + ) + return dq, dk, dv, *((None,) * 20) # Extra Nones is fine + + +class FlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: 
Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), + learnable_sink: Optional[torch.Tensor] = None, + softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, + score_mod: Optional[Callable] = None, + aux_tensors: Optional[list] = None, + ): + out, lse = _flash_attn_fwd( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + page_table=page_table, + softmax_scale=softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + learnable_sink=learnable_sink, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + score_mod=score_mod, + aux_tensors=aux_tensors, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.deterministic = deterministic + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + return out, lse + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors + assert ctx.softcap == 0.0 + dq, dk, dv = _flash_attn_bwd( + q, + k, + v, + out, + dout, + lse, + ctx.softmax_scale, + ctx.causal, + ctx.softcap, + window_size_left=ctx.window_size[0], + window_size_right=ctx.window_size[1], + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_k=ctx.max_seqlen_k, + deterministic=ctx.deterministic, + ) + + return dq, dk, dv, *((None,) * 20) + + +def flash_attn_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), + learnable_sink: Optional[torch.Tensor] = None, + softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, + mask_mod: Optional[Callable] = None, + full_block_cnt: Optional[torch.Tensor] = None, + full_block_idx: Optional[torch.Tensor] = None, + mask_block_cnt: Optional[torch.Tensor] = None, + mask_block_idx: Optional[torch.Tensor] = None, +): + return FlashAttnFunc.apply( + q, + k, + v, + softmax_scale, + causal, + window_size, + learnable_sink, + softcap, + num_splits, + pack_gqa, + deterministic, + mask_mod, + full_block_cnt, + full_block_idx, + mask_block_cnt, + mask_block_idx, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), + learnable_sink: Optional[torch.Tensor] = None, + softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: 
Optional[bool] = None, + deterministic: bool = False, + score_mod: Optional[Callable] = None, + aux_tensors: Optional[list] = None, +): + return FlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + page_table, + softmax_scale, + causal, + window_size, + learnable_sink, + softcap, + num_splits, + pack_gqa, + deterministic, + score_mod, + aux_tensors, + ) + + +def _flash_attn_fwd_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: torch.Tensor, + lse: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + seqused: Optional[torch.Tensor] = None, + num_splits_dynamic_ptr: Optional[torch.Tensor] = None, + semaphore_to_reset: Optional[torch.Tensor] = None, +) -> None: + """Forward combine kernel for split attention computation. + + Combines partial outputs and log-sum-exp values from multiple splits + of attention computation into final outputs. + + Args: + out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim) or + (num_splits, total_q, nheads, headdim) if there's cu_seqlens + lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads) or + (num_splits, total_q, nheads) if there's cu_seqlens + out: Output tensor (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim) if there's cu_seqlens + lse: Output LSE tensor (batch, seqlen, nheads) or (total_q, nheads) if there's cu_seqlens. + cu_seqlens: Cumulative sequence lengths for variable length sequences + seqused: Used sequence lengths for each batch + num_splits_dynamic_ptr: Dynamic number of splits per batch + semaphore_to_reset: Semaphore for synchronization + k_block_size: Block size for head dimension + + Returns: + None + """ + # Input validation + assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" + assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" + assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], ( + "out_partial must be fp16, bf16, or fp32" + ) + assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" + assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device" + assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension" + assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension" + assert lse_partial.shape == out_partial.shape[:-1] + + # Determine if this is variable length based on dimensions + is_varlen = out_partial.dim() == 4 + + # Validate output tensor shapes and types + assert out.shape == out_partial.shape[1:], "out shape mismatch" + if lse is not None: + assert lse.shape == lse_partial.shape[1:], "lse shape mismatch" + assert lse.dtype == torch.float32, "lse must be fp32" + + # Validate optional tensors + for t, name in [ + (cu_seqlens, "cu_seqlens"), + (seqused, "seqused"), + (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"), + ]: + if t is not None: + assert t.dtype == torch.int32, f"{name} must be int32" + assert t.is_cuda, f"{name} must be on CUDA device" + assert t.is_contiguous(), f"{name} must be contiguous" + + head_dim = out_partial.shape[-1] + num_splits = out_partial.shape[0] + assert num_splits <= 256 + # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively + # so that kBlockM is smaller and we have more parallelism. + k_block_size = 64 if head_dim <= 64 else 128 + # We want kBlockM to be as small as possible to maximize parallelism. 
+ # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). + m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32) + log_max_splits = max(math.ceil(math.log2(num_splits)), 4) + if m_block_size == 8: + # If kBlockM == 8 then the minimum number of splits is 32. + # TODO: we can deal w this by using 128 threads instead + log_max_splits = max(log_max_splits, 5) + + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Create combine kernel configuration + dtype = torch2cute_dtype_map[out.dtype] + dtype_partial = torch2cute_dtype_map[out_partial.dtype] + + compile_key = ( + dtype, + dtype_partial, + head_dim, + m_block_size, + k_block_size, + log_max_splits, + cu_seqlens is not None, + seqused is not None, + lse is not None, + ) + + if compile_key not in _flash_attn_fwd_combine.compile_cache: + out_partial_tensor = to_cute_tensor( + out_partial, leading_dim=4 if not is_varlen else 3 + ) + lse_partial_tensor = to_cute_tensor( + lse_partial, assumed_align=4, leading_dim=lse_partial.ndim - 2 + ) + out_tensor = to_cute_tensor(out, leading_dim=3 if not is_varlen else 2) + lse_tensor = ( + to_cute_tensor(lse, assumed_align=4, leading_dim=lse.ndim - 2) + if lse is not None + else None + ) + + optional_tensors = [ + to_cute_tensor(t, assumed_align=4, leading_dim=0) + if t is not None + else None + for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset) + ] + cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = ( + optional_tensors + ) + fa_combine = FlashAttentionForwardCombine( + dtype=dtype, + dtype_partial=dtype_partial, + head_dim=head_dim, + m_block_size=m_block_size, + k_block_size=k_block_size, + log_max_splits=log_max_splits, + ) + + # Check if implementation is supported + if not fa_combine.can_implement( + dtype, + dtype_partial, + head_dim, + m_block_size, + k_block_size, + log_max_splits, + num_threads=256, + ): + raise RuntimeError( + "FlashAttention combine kernel cannot be implemented with given parameters" + ) + + _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile( + fa_combine, + out_partial_tensor, + lse_partial_tensor, + out_tensor, + lse_tensor, + cu_seqlens_tensor, + seqused_tensor, + num_splits_dynamic_tensor, + semaphore_tensor, + current_stream, + options="--enable-tvm-ffi", + ) + _flash_attn_fwd_combine.compile_cache[compile_key]( + out_partial, + lse_partial, + out, + lse, + cu_seqlens, + seqused, + num_splits_dynamic_ptr, + semaphore_to_reset, + current_stream, + ) + + +_flash_attn_fwd_combine.compile_cache = {} + + +def flash_attn_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + cu_seqlens: Optional[torch.Tensor] = None, + seqused: Optional[torch.Tensor] = None, + return_lse: bool = True, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Flash Attention combine function for split attention computation. + + Combines partial outputs and log-sum-exp values from multiple splits + of attention computation into final outputs. This is the main user-facing + interface for the combine kernel. 
+ + Args: + out_partial: Partial outputs tensor with shape: + - (num_splits, batch_size, seqlen, num_heads, head_size) for regular batched input + - (num_splits, total_q, num_heads, head_size) for variable length input + lse_partial: Partial LSE tensor with shape: + - (num_splits, batch_size, seqlen, num_heads) for regular batched input + - (num_splits, total_q, num_heads) for variable length input + out: Optional output tensor. If None, will be created automatically. + out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input. + cu_seqlens: Cumulative sequence lengths for variable length sequences + seqused: Used sequence lengths for each batch + return_lse: Whether to return the combined LSE tensor. Default is True. + + Returns: + Tuple of (out, lse) where: + - out: Combined output tensor with shape (batch_size, seqlen, num_heads, head_size) + or (total_q, num_heads, head_size) for varlen + - lse: Combined log-sum-exp tensor with shape (batch_size, seqlen, num_heads) + or (total_q, num_heads) for varlen. None if return_lse=False + + Note: + This function expects the input tensors to be in the format produced by + split attention computation, where the first dimension is num_splits. + The permuting from user format to kernel format is now done inside the kernel. + """ + # Input validation + assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" + assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" + assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)" + assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" + + # Determine if this is variable length based on dimensions + is_varlen = out_partial.dim() == 4 + + if is_varlen: + # Variable length: (num_splits, total_q, num_heads, head_size) + num_splits, total_q, num_heads, head_size = out_partial.shape + assert lse_partial.shape == (num_splits, total_q, num_heads), ( + "lse_partial shape mismatch for varlen" + ) + batch_size = 1 # Treat as single batch for varlen + seqlen = total_q + else: + # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size) + num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape + assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), ( + "lse_partial shape mismatch" + ) + + # Determine output dtype + if out_dtype is None: + out_dtype = out_partial.dtype + + # Create output if not provided + device = out_partial.device + if out is None: + if is_varlen: + out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device) + else: + out = torch.empty( + batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device + ) + + # Create lse output only if requested + if return_lse: + if is_varlen: + lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose( + 0, 1 + ) + else: + lse = torch.empty( + batch_size, num_heads, seqlen, dtype=torch.float32, device=device + ).transpose(1, 2) + else: + lse = None + + _flash_attn_fwd_combine( + out_partial, + lse_partial, + out, + lse, + cu_seqlens, + seqused, + ) + return out, lse diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py new file mode 100644 index 00000000000..c0ba457b129 --- /dev/null +++ b/flash_attn/cute/mask.py @@ -0,0 +1,651 @@ +# Copyright (c) 2025, Tri Dao. 
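+# Masking utilities for the CuTe flash-attention kernels: bit-twiddling helpers (mask_r2p and
+# variants) that compile down to the R2P instruction, plus the AttentionMask dataclass whose
+# apply_mask / apply_mask_sm100 / apply_mask_sm100_transposed methods implement seqlen, causal,
+# local (sliding-window), mask_mod and block-sparse masking for sm90 and sm100.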
+ +from typing import Optional, Callable +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr + +import flash_attn.cute.utils as utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK + + +@cute.jit +def mask_r2p(X: cute.Tensor, col_limit: Int32, arch: int = 90, rank1: bool = False) -> None: + # Bit manipulation, compiles down to the R2P instruction + # For sm100: we know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using. + # For sm90: instead of comparing limit to 0, 1, 8, 9, 16, 17, ..., + # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... + if const_expr(arch == 90): + col_limit_transformed = col_limit // 8 * 2 + min(col_limit % 8, 2) + else: + col_limit_transformed = col_limit + ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape)) + # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + # Don't need to clamp to 32 since the shr.u32 instruction does that already + col_limit_right_s = max(col_limit_transformed - s * 24, 0) + # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 + mask = (1 << col_limit_right_s) - 1 + # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + in_bound = cutlass.Boolean(mask & (1 << i)) + c = s * 24 + i + if const_expr(rank1): + X[c] = X[c] if in_bound else -Float32.inf + # This is the equivalent of: + # X[s * 24 + i] = X[s * 24 + i] if col_limit_right_s <= i else -Float32.inf + else: + for r in cutlass.range_constexpr(cute.size(X.shape[0])): + X[r, c] = X[r, c] if in_bound else -Float32.inf + + +@cute.jit +def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> None: + # Bit manipulation, compiles down to the R2P instruction + # For sm100: we know that tScS_t2r[i][0] has the form 0, 1, ..., 31, 64, ..., 127 + # or 0, 1, ..., 15, 32, ..., 47, 64, ... + # We compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... + # Here we hardcode for the case of 2 warp groups. + num_wg = 2 + row_limit_top_transformed = row_limit_top // (num_rep * num_wg) * num_rep + min( + row_limit_top % (num_rep * num_wg), num_rep + ) + ncol = cute.size(X.shape) + # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + row_limit_top_s = max(row_limit_top_transformed - s * 24, 0) + # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 + mask = (1 << row_limit_top_s) - 1 + # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + out_bound = cutlass.Boolean(mask & (1 << i)) + c = s * 24 + i + X[c] = -Float32.inf if out_bound else X[c] + # tidx = cute.arch.thread_idx()[0] % 256 + # if tidx == 128: + # cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound) + + +@cute.jit +def mask_r2p_dual_bound( + X: cute.Tensor, + col_limit_left: Int32, # Inclusive lower bound + col_limit_right: Int32, # Exclusive upper bound +) -> None: + """ + Dual-bound masking using two bitmasks for SM100, following mask_r2p. 
+ Masks elements where: NOT (col_limit_left <= col < col_limit_right) + + Uses bit manipulation to create a range mask: + mask_right = (1 << right) - 1 -> bits (right-1)..0 are 1 + mask_left = (1 << left) - 1 -> bits (left-1)..0 are 1 + mask_range = mask_range = mask_right & ~ mask_left -> bits (right-1)..left are 1 + """ + ncol = const_expr(cute.size(X.shape)) + + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + right_s = max(col_limit_right - s * 24, 0) + left_s = max(col_limit_left - s * 24, 0) + + # otherwise cute dsl complains about python int too large to convert into c long + right_s = min(right_s, 24) + left_s = min(left_s, 24) + + # bits (right-1)..left are 1 + mask_right = (1 << right_s) - 1 + mask_left = (1 << left_s) - 1 + mask_range = mask_right & ~mask_left + + # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + in_bound = cutlass.Boolean(mask_range & (1 << i)) + c = s * 24 + i + X[c] = X[c] if in_bound else -Float32.inf + + +@dataclass(frozen=True) +class AttentionMask: + tile_m: cutlass.Constexpr[int] + tile_n: cutlass.Constexpr[int] + seqlen_info: SeqlenInfoQK + window_size_left: Optional[Int32] = None + window_size_right: Optional[Int32] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 # only pass in if we're doing PackGQA + swap_AB: cutlass.Constexpr[bool] = False + + @property + def seqlen_q(self) -> Int32: + return self.seqlen_info.seqlen_q + + @property + def seqlen_k(self) -> Int32: + return self.seqlen_info.seqlen_k + + @cute.jit + def apply_mask( + self, + acc_S: cute.Tensor, + batch_idx: cutlass.Int32, + head_idx: cutlass.Int32, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + thr_mma: cute.TiledMma, + mask_seqlen: cutlass.Constexpr[bool], + mask_causal: cutlass.Constexpr[bool], + mask_local: cutlass.Constexpr[bool] = False, + mask_mod: cutlass.Constexpr[Optional[Callable]] = None, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + ) -> None: + assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB) + acc_shape = (self.tile_m, self.tile_n) + cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1]) + tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS), transpose=self.swap_AB) + # We use t0ScS as these indices are known at compile time. We then must subtract the + # column limit by the thread column offset. + t0ScS_mn = utils.make_acc_tensor_mn_view( + thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB + ) + ROW = 0 if const_expr(not self.swap_AB) else 1 + COL = 1 if const_expr(not self.swap_AB) else 0 + thr_col_offset = tScS_mn[0][COL] + # To handle edge cases of completely masked out rows where n_block_max = 0, + # we treat negative n_blocks as 0th n_block + # TODO: find more transparent solution + if n_block < 0: + n_block = 0 + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset + if const_expr(not mask_causal and not mask_local and mask_mod is None): + if const_expr(mask_seqlen): + # The compiler now choses not to use R2P + r2p = const_expr(False and not self.swap_AB) + if const_expr(not r2p): + # traverse column index. 
+ for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c] + else: + mask_r2p(acc_S_mn, seqlenk_col_limit, arch=90) + + elif const_expr( + not mask_causal and not mask_local and mask_mod is not None + ): # FlexAttention mask mod + nrow = const_expr(cute.size(tScS_mn.shape[0])) + ncol = const_expr(cute.size(tScS_mn.shape[1])) + has_fastdiv = const_expr( + fastdiv_mods is not None + and fastdiv_mods[0] is not None + and fastdiv_mods[1] is not None + ) + wrap_aux_indices = const_expr( + has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None) + ) + + for r in cutlass.range_constexpr(nrow): + # Respect swap_AB: ROW/COL determine which coordinate component corresponds to Q/KV. + local_row = tScS_mn[r, 0][ROW] + global_row_idx = local_row + m_block * self.tile_m + row_for_mod = global_row_idx + head_idx_for_mod = head_idx + if const_expr(self.qhead_per_kvhead_packgqa != 1): + head_offset = global_row_idx % self.qhead_per_kvhead_packgqa + head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset + row_for_mod = global_row_idx // self.qhead_per_kvhead_packgqa + row_for_seqlen = row_for_mod + if const_expr(wrap_aux_indices): + _, row_for_mod = divmod(row_for_mod, fastdiv_mods[0]) + + for col in cutlass.range_constexpr(ncol): + col_idx_local = t0ScS_mn[0, col][COL] + # Convert to absolute column index + global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n + col_for_mod = global_col_idx + if const_expr(wrap_aux_indices): + _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1]) + + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32) + q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32) + kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32) + mask_value = mask_mod( + batch_idx_ssa, + head_idx_ssa, + q_idx_ssa, + kv_idx_ssa, + self.seqlen_info, + aux_tensors, + ) + cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) + if const_expr(mask_seqlen): + out_of_bounds = (row_for_seqlen >= self.seqlen_q) or ( + global_col_idx >= self.seqlen_k + ) + if out_of_bounds: + acc_S_mn[r, col] = -cutlass.Float32.inf + else: + acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf + else: + acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf + + else: # Causal or local + if const_expr(not self.swap_AB): + # If PackGQA, we split the work of compute divmod among threads in the same row + threads_per_row = thr_mma.tv_layout_C.shape[0][0] + mma_m_idx = None + if const_expr(self.qhead_per_kvhead_packgqa != 1): + assert not self.swap_AB, "swap_AB with PackGQA not supported yet" + assert cute.arch.WARP_SIZE % threads_per_row == 0, ( + "threads_per_row must divide WARP_SIZE" + ) + assert cute.size(acc_S_mn.shape[0]) <= threads_per_row + tidx = thr_mma.thr_idx + mma_m_idx = ( + m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0] + ) // self.qhead_per_kvhead_packgqa + causal_row_offset = ( + 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset + ) + if const_expr(mask_causal): + r2p = const_expr(not self.swap_AB) # R2P trick, see apply_mask_sm100 + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. 
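+                        # Worked example (illustrative): with thr_col_offset == 0, n_block == 0 and
+                        # seqlen_q == seqlen_k we get causal_row_offset == 1, so col_limit_right ==
+                        # row_idx + 1, i.e. row row_idx keeps keys 0..row_idx and masks the rest; the
+                        # seqlen_k - seqlen_q term keeps the causal diagonal aligned to the bottom-right.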
+ if const_expr(self.qhead_per_kvhead_packgqa == 1): + row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m + else: + row_idx = utils.shuffle_sync( + mma_m_idx, r % threads_per_row, width=threads_per_row + ) + col_limit_right = row_idx + causal_row_offset + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + if const_expr(not r2p): + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + acc_S_mn[r, c] = ( + -Float32.inf + if t0ScS_mn[0, c][1] >= col_limit_right + else acc_S_mn[r, c] + ) + else: + mask_r2p(acc_S_mn[r, None], col_limit_right, arch=90, rank1=True) + else: # Local + local_row_offset_right = ( + causal_row_offset + self.window_size_right + if const_expr(self.window_size_right is not None) + else None + ) + local_row_offset_left = ( + causal_row_offset - 1 - self.window_size_left + if const_expr(self.window_size_left is not None) + else None + ) + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + if const_expr(self.qhead_per_kvhead_packgqa == 1): + row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m + else: + row_idx = utils.shuffle_sync( + mma_m_idx, r % threads_per_row, width=threads_per_row + ) + if const_expr(self.window_size_right is not None): + col_limit_right = row_idx + local_row_offset_right + else: + col_limit_right = self.tile_n + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + col_limit_left = ( + row_idx + local_row_offset_left + if const_expr(self.window_size_left is not None) + else 0 + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left) + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + col_idx = t0ScS_mn[0, c][1] + # only consider the column index, so the row index sets to 0. + if col_idx >= col_limit_right or col_idx < col_limit_left: + acc_S_mn[r, c] = -Float32.inf + else: # swap_AB + assert self.qhead_per_kvhead_packgqa == 1 + thr_row_offset = tScS_mn[0][ROW] + causal_row_offset = ( + seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset + ) + if const_expr(mask_causal): + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + col0 = t0ScS_mn[0, c][COL] + # If col0 is beyond the column limit, we want to mask out the entire + # column, by setting row limit to be self.tile_m. + row_limit_top = ( + self.tile_m + if col0 >= seqlenk_col_limit and mask_seqlen + else col0 - causal_row_offset + ) + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + acc_S_mn[r, c] = ( + -Float32.inf + if t0ScS_mn[r, 0][ROW] < row_limit_top + else acc_S_mn[r, c] + ) + else: + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + col0 = t0ScS_mn[0, c][COL] + # If col0 is beyond the column limit, we want to mask out the entire + # column, by setting row limit to be self.tile_m. + row_limit_top = ( + self.tile_m + if col0 >= seqlenk_col_limit + else col0 - causal_row_offset - self.window_size_right + ) + # TODO: do we need col_limit_sink? 
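+                            # Illustrative reading: in swap_AB the sliding window becomes a band of rows,
+                            # so a (local) key column col0 is kept for (local) query rows in
+                            # [col0 - causal_row_offset - window_size_right,
+                            #  col0 - causal_row_offset + window_size_left]; rows outside it get -inf below.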
+ row_limit_bot = col0 - causal_row_offset + self.window_size_left + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + row_idx = t0ScS_mn[r, 0][ROW] + acc_S_mn[r, c] = ( + -Float32.inf + if row_idx < row_limit_top or row_idx > row_limit_bot + else acc_S_mn[r, c] + ) + + @cute.jit + def apply_mask_sm100( + self, + acc_S: cute.Tensor, + m_block: Int32, + n_block: Int32, + thr_mma: cute.TiledMma, + thr_tmem_load: cute.TiledCopy, + mask_seqlen: cutlass.Constexpr[bool], + mask_causal: cutlass.Constexpr[bool], + mask_local: cutlass.Constexpr[bool] = False, + mask_mod: cutlass.Constexpr[Optional[Callable]] = None, + batch_idx: Int32 = None, + head_idx: Int32 = None, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + head_divmod=None, + check_q_boundary: bool = False, + ) -> None: + assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" + acc_shape = (self.tile_m, self.tile_n) + cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1]) + tScS = thr_mma.partition_C(cS) + tScS_t2r = thr_tmem_load.partition_D(tScS) + # To handle edge cases of completely masked out rows where n_block_max = 0, + # we treat negative n_blocks as 0th n_block + # TODO: find more transparent solution + if n_block < 0: + n_block = 0 + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n + r2p = True + if const_expr(not mask_causal and not mask_local and mask_mod is None): + if const_expr(mask_seqlen): + if const_expr(not r2p): + for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): + # if tScS_t2r[i][1] >= seqlenk_col_limit: + # acc_S[i] = -Float32.inf + # For some reason the 2 lines above generate really bad SASS + acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] + else: + mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True) + + elif const_expr(not mask_causal and not mask_local and mask_mod is not None): + # Block sparse case w/ mask_mod + has_fastdiv = const_expr( + fastdiv_mods is not None + and fastdiv_mods[0] is not None + and fastdiv_mods[1] is not None + ) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + + ncol = const_expr(cute.size(tScS_t2r.shape)) + for i in cutlass.range_constexpr(ncol): + row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1] + col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] + global_row = row_coord + m_block * self.tile_m + global_col = col_coord + n_block * self.tile_n + + if const_expr(self.qhead_per_kvhead_packgqa != 1): + assert head_divmod is not None + mask_row, head_offset = divmod(global_row, head_divmod) + head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset + else: + head_idx_for_mod = head_idx + mask_row = global_row + + mask_row_for_mod = mask_row + if const_expr(has_fastdiv and aux_tensors is not None): + if check_q_boundary: + _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) + global_col_for_mod = global_col + if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None): + _, global_col_for_mod = divmod(global_col, fastdiv_mods[1]) + + head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32) + mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32) + kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32) + mask_value = mask_mod( + batch_idx_ssa, + head_idx_ssa, + mask_row_ssa, + kv_idx_ssa, + self.seqlen_info, + aux_tensors, + ) + cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) + acc_S[i] = acc_S[i] if 
cond else -Float32.inf + if const_expr(mask_seqlen): + acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i] + if check_q_boundary: + acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i] + + else: # Causal or local + causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q + row_idx = tScS_t2r[0][0] + m_block * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa != 1): + row_idx = row_idx // self.qhead_per_kvhead_packgqa + if const_expr(mask_causal): + col_limit_right = row_idx + causal_row_offset + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + # if cute.arch.thread_idx()[0] % 32 == 0: + # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) + ncol = const_expr(cute.size(tScS_t2r.shape)) + if const_expr(not r2p): + for i in cutlass.range(ncol, unroll_full=True): + acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] + else: + mask_r2p(acc_S, col_limit_right, arch=100, rank1=True) + else: + local_row_offset_right = ( + causal_row_offset + self.window_size_right + if const_expr(self.window_size_right is not None) + else None + ) + local_row_offset_left = ( + causal_row_offset - 1 - self.window_size_left + if const_expr(self.window_size_left is not None) + else None + ) + if const_expr(self.window_size_right is not None): + col_limit_right = row_idx + local_row_offset_right + else: + col_limit_right = self.tile_n + if const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + col_limit_left = ( + row_idx + local_row_offset_left + if const_expr(self.window_size_left is not None) + else 0 + ) + if const_expr(not r2p): + # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left) + for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): + col_idx = tScS_t2r[i][1] + acc_S[i] = ( + -Float32.inf + if col_idx >= col_limit_right or col_idx < col_limit_left + else acc_S[i] + ) + else: + # XOR-based R2P dual bound masking + mask_r2p_dual_bound(acc_S, col_limit_left, col_limit_right) + + @cute.jit + def apply_mask_sm100_transposed( + self, + acc_S: cute.Tensor, + tScS_t2r: cute.Tensor, + t0ScS_t2r: cute.Tensor, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + mask_seqlen: cutlass.Constexpr, + mask_causal: cutlass.Constexpr, + mask_local: cutlass.Constexpr, + mask_mod: cutlass.Constexpr[Optional[Callable]] = None, + batch_idx: Int32 = None, + head_idx: Int32 = None, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + is_full_block: bool = False, + check_m_boundary: bool = True, + ) -> None: + """ + Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q. + + Coordinate conventio: + - ROW corresponds to Q (m_block) + - COL corresponds to KV (n_block) + + is_full_block: If True, skip mask_mod (all elements valid). Only apply seqlen masking. + check_m_boundary: If False, skip seqlen_q boundary check (optimization for non-boundary m_blocks). + When iterating m_blocks in forward order, only the last m_block may be partial. 
+ """ + assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" + ROW = 0 if const_expr(not self.swap_AB) else 1 + COL = 1 if const_expr(not self.swap_AB) else 0 + assert t0ScS_t2r[0][COL] == 0, "col0 == 0" + thr_col_offset = tScS_t2r[0][COL] + seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset + + if const_expr(not mask_causal and not mask_local and mask_mod is not None): + # Block sparse case with mask_mod (backward) + # + # Coordinate convention: ROW → Q (m_block), COL → KV (n_block). + # These already account for swap_AB. + # + # FULL blocks: mask_mod returns True for all elements, so skip it. + # Still need seqlen bounds check (elements may be OOB on last m_block). + # PARTIAL blocks: apply mask_mod element-wise, then seqlen bounds. + if is_full_block: + if const_expr(mask_seqlen): + if seqlenk_col_limit <= 0: + # Entire tile is OOB for K + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + acc_S[i] = -cutlass.Float32.inf + elif check_m_boundary: + # Last m_block: check Q and K boundaries + ncol = const_expr(cute.size(tScS_t2r.shape)) + for i in cutlass.range_constexpr(ncol): + row_coord = tScS_t2r[i][ROW] + col_coord = tScS_t2r[i][COL] + global_q = row_coord + m_block * self.tile_m + global_kv = col_coord + n_block * self.tile_n + q_out_of_bounds = global_q >= self.seqlen_q + kv_out_of_bounds = global_kv >= self.seqlen_k + out_of_bounds = q_out_of_bounds or kv_out_of_bounds + acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i] + else: + # Partial block + has_fastdiv = const_expr( + fastdiv_mods is not None + and fastdiv_mods[0] is not None + and fastdiv_mods[1] is not None + ) + wrap_aux_indices = const_expr( + has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None) + ) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) + + ncol = const_expr(cute.size(tScS_t2r.shape)) + for i in cutlass.range_constexpr(ncol): + row_coord = tScS_t2r[i][ROW] + col_coord = tScS_t2r[i][COL] + global_q = row_coord + m_block * self.tile_m + global_kv = col_coord + n_block * self.tile_n + + q_idx_for_mod = global_q + kv_idx_for_mod = global_kv + if const_expr(wrap_aux_indices): + _, q_idx_for_mod = divmod(global_q, fastdiv_mods[0]) + _, kv_idx_for_mod = divmod(global_kv, fastdiv_mods[1]) + + q_idx_ssa = utils.scalar_to_ssa(q_idx_for_mod, cutlass.Int32) + kv_idx_ssa = utils.scalar_to_ssa(kv_idx_for_mod, cutlass.Int32) + + mask_value = mask_mod( + batch_idx_ssa, + head_idx_ssa, + q_idx_ssa, + kv_idx_ssa, + self.seqlen_info, + aux_tensors, + ) + cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) + acc_S[i] = acc_S[i] if cond else -cutlass.Float32.inf + + if const_expr(mask_seqlen): + # check_m_boundary=False skips q check for non-boundary m_blocks + q_out_of_bounds = check_m_boundary and (global_q >= self.seqlen_q) + kv_out_of_bounds = global_kv >= self.seqlen_k + out_of_bounds = q_out_of_bounds or kv_out_of_bounds + acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i] + + elif const_expr(not mask_causal and not mask_local): + if const_expr(mask_seqlen): + if seqlenk_col_limit <= 0: + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + acc_S[i] = -cutlass.Float32.inf + else: # Causal or local + thr_row_offset = tScS_t2r[0][ROW] + seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset + causal_offset = seqlenq_row_limit - seqlenk_col_limit + if const_expr(mask_causal): + # tidx = 
cute.arch.thread_idx()[0] % 256 + # if tidx < 32: + # cute.printf("tidx = {}, {} {}, {} {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1]) + row_limit_top = causal_offset + if const_expr(mask_seqlen): + # If col is beyond the column limit, we want to mask out the entire + # column, by setting row limit to be self.tile_m. + if seqlenk_col_limit <= 0: + row_limit_top = self.tile_m + r2p = True + if const_expr(not r2p): + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + acc_S[i] = ( + -cutlass.Float32.inf if t0ScS_t2r[i][ROW] < row_limit_top else acc_S[i] + ) + else: + num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32 + mask_r2p_transposed(acc_S, row_limit_top, num_rep) + else: + if const_expr(self.window_size_right is not None): + row_limit_top = causal_offset - self.window_size_right + else: + row_limit_top = 0 + if const_expr(self.window_size_left is not None): + row_limit_bot = causal_offset + self.window_size_left + if const_expr(mask_seqlen): + if seqlenk_col_limit <= 0: + row_limit_top = self.tile_m + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + row_idx = t0ScS_t2r[i][ROW] + local_mask = row_idx < row_limit_top + if const_expr(self.window_size_left is not None): + local_mask |= row_idx > row_limit_bot + acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i] diff --git a/flash_attn/cute/mma_sm100_desc.py b/flash_attn/cute/mma_sm100_desc.py new file mode 100644 index 00000000000..16336c34686 --- /dev/null +++ b/flash_attn/cute/mma_sm100_desc.py @@ -0,0 +1,291 @@ +# Copyright (c) 2025, Tri Dao. +# Ported Cutlass code from C++ to Python: +# https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm100_desc.hpp +# https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm100.hpp + +from enum import IntEnum + +import cutlass +import cutlass.cute as cute + +# --------------------------------------------------------------------------- +# Enumerations that match the HW encodings (values MUST stay identical) +# --------------------------------------------------------------------------- + + +class Major(IntEnum): # matrix “layout” in the ISA docs + K = 0 + MN = 1 + + +class ScaleIn(IntEnum): # negate flags + One = 0 + Neg = 1 + + +class Saturate(IntEnum): + False_ = 0 + True_ = 1 + + +class CFormat(IntEnum): # 2-bit field (bits 4-5) + F16 = 0 + F32 = 1 + S32 = 2 + + +class F16F32Format(IntEnum): # 3-bit field (A/B element type) + F16 = 0 + BF16 = 1 + TF32 = 2 + + +class S8Format(IntEnum): + UINT8 = 0 + INT8 = 1 + + +class MXF8F6F4Format(IntEnum): + E4M3 = 0 + E5M2 = 1 + E2M3 = 3 + E3M2 = 4 + E2M1 = 5 + + +class MaxShift(IntEnum): + NoShift = 0 + MaxShift8 = 1 + MaxShift16 = 2 + MaxShift32 = 3 + + +# --------------------------------------------------------------------------- +# CUTLASS-type → encoding helpers +# --------------------------------------------------------------------------- + + +def to_UMMA_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B. 
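+    For example (per the mappings below): cutlass.Float16 -> F16F32Format.F16 (0),
+    cutlass.BFloat16 -> F16F32Format.BF16 (1), cutlass.FloatE4M3FN -> MXF8F6F4Format.E4M3 (0).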
+ """ + if cutlass_type is cutlass.Int8: + return S8Format.INT8 + # Unsigned 8-bit (if available in your CUTLASS build) + if cutlass_type is cutlass.Uint8: + return S8Format.UINT8 + # FP-16 / BF-16 + if cutlass_type is cutlass.Float16: + return F16F32Format.F16 + if cutlass_type is cutlass.BFloat16: + return F16F32Format.BF16 + # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits) + if cutlass_type is cutlass.TFloat32: + return F16F32Format.TF32 + # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them + if cutlass_type is cutlass.FloatE4M3FN: + return MXF8F6F4Format.E4M3 + if cutlass_type is cutlass.FloatE5M2: + return MXF8F6F4Format.E5M2 + raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") + + +def to_C_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 2-bit accumulator encoding. + """ + if cutlass_type is cutlass.Float16: + return CFormat.F16 + if cutlass_type is cutlass.Float32: + return CFormat.F32 + if cutlass_type is cutlass.Int32: + return CFormat.S32 + raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}") + + +# --------------------------------------------------------------------------- +# The constructor – accepts only CUTLASS scalar classes +# --------------------------------------------------------------------------- + + +def make_instr_desc( + a_type, # CUTLASS scalar class, e.g. cutlass.Int8 + b_type, + c_type, + M: int, # 64, 128 or 256 + N: int, # 8 … 256 (multiple of 8) + a_major: Major, + b_major: Major, + a_neg: ScaleIn = ScaleIn.One, + b_neg: ScaleIn = ScaleIn.One, + c_sat: Saturate = Saturate.False_, + is_sparse: bool = False, + max_shift: MaxShift = MaxShift.NoShift, +) -> int: + """ + Build the 32-bit instruction descriptor for Blackwell MMA. + All matrix/accumulator **types must be CUTLASS scalar classes** – + passing integers is forbidden. 
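+    Example (illustrative): make_instr_desc(cutlass.Float16, cutlass.Float16, cutlass.Float32,
+    128, 256, Major.K, Major.K) encodes a 128x256 fp16 MMA with an fp32 accumulator.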
+ """ + # --- encode element formats ------------------------------------------------- + a_fmt = int(to_UMMA_format(a_type)) + b_fmt = int(to_UMMA_format(b_type)) + c_fmt = int(to_C_format(c_type)) + + # --- range checks on M/N ----------------------------------------------------- + if M not in (64, 128, 256): + raise ValueError("M must be 64, 128 or 256") + if N < 8 or N > 256 or (N & 7): + raise ValueError("N must be a multiple of 8 in the range 8…256") + + m_dim = M >> 4 # 5-bit field + n_dim = N >> 3 # 6-bit field + + # fmt: off + # --- pack the bit-fields ----------------------------------------------------- + desc = 0 + desc |= (0 & 0x3) << 0 # sparse_id2 (always 0 here) + desc |= (int(is_sparse) & 0x1) << 2 # sparse_flag + desc |= (int(c_sat) & 0x1) << 3 # saturate + desc |= (c_fmt & 0x3) << 4 # c_format + desc |= (a_fmt & 0x7) << 7 # a_format + desc |= (b_fmt & 0x7) << 10 # b_format + desc |= (int(a_neg) & 0x1) << 13 # a_negate + desc |= (int(b_neg) & 0x1) << 14 # b_negate + desc |= (int(a_major) & 0x1) << 15 # a_major + desc |= (int(b_major) & 0x1) << 16 # b_major + desc |= (n_dim & 0x3F) << 17 # n_dim (6 bits) + desc |= (m_dim & 0x1F) << 24 # m_dim (5 bits) + desc |= (int(max_shift) & 0x3) << 30 # max_shift (2 bits) + # fmt: on + + return desc & 0xFFFF_FFFF # ensure 32-bit result + + +def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp): + return make_instr_desc( + op.a_dtype, + op.b_dtype, + op.acc_dtype, + op.shape_mnk[0], + op.shape_mnk[1], + Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + ) + + +class LayoutType(IntEnum): # occupies the top-3 bits [61:64) + SWIZZLE_NONE = 0 # (a.k.a. “INTERLEAVE” in older docs) + SWIZZLE_128B_BASE32B = 1 + SWIZZLE_128B = 2 + SWIZZLE_64B = 4 + SWIZZLE_32B = 6 + # values 3,5,7 are reserved / illegal for UMMA + + +# --------------------------------------------------------------------------- +# Helpers – figure out the SWIZZLE_* family from the tensor layout +# --------------------------------------------------------------------------- + + +def _layout_type(swizzle: cute.Swizzle) -> LayoutType: + # No idea what the right way to get B, M, S is – so we're just parsing it from the __str__ + # Swizzle string has the form "S" + swz_str = str(swizzle) + inside = swz_str[swz_str.index("<") + 1 : swz_str.index(">")] # '3,4,3' + B, M, S = [int(x) for x in inside.split(",")] # [3, 4, 3] + + if M == 4: # Swizzle<*,4,3> + if S != 3: + raise ValueError("Unexpected swizzle shift – want S==3 for M==4") + return { + 0: LayoutType.SWIZZLE_NONE, + 1: LayoutType.SWIZZLE_32B, + 2: LayoutType.SWIZZLE_64B, + 3: LayoutType.SWIZZLE_128B, + }[B] # KeyError ⇒ invalid B→ raise + if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5) + if (B, S) != (2, 2): + raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B") + return LayoutType.SWIZZLE_128B_BASE32B + + # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout + raise ValueError("Unsupported swizzle triple for UMMA smem descriptor") + + +def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int: + """ + Convert a 2-D *shared-memory* Cute layout into the Blackwell 64-bit + smem-descriptor, without the smem start address. + layout must correspond to layout of an uint128 tensor. 
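+    The 14-bit smem start address field (bits 0-13) is left zero here; see
+    make_smem_desc_start_addr for the value that belongs there.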
+ """ + # ------------------------------------------------------------------ meta + layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family + + VERSION = 1 # bits 46–47 + LBO_MODE = 0 # bit 52 + BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0) + + # ---------------------------------------------------------- strides (units: uint128_t = 16 B) + swizzle_atom_mn_size = { + LayoutType.SWIZZLE_NONE: 1, + LayoutType.SWIZZLE_32B: 2, + LayoutType.SWIZZLE_64B: 4, + LayoutType.SWIZZLE_128B: 8, + LayoutType.SWIZZLE_128B_BASE32B: 8, + }[layout_type] + + if major is Major.MN: + swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8 + canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if stride_10 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1] + if layout_type is LayoutType.SWIZZLE_NONE: + stride_byte_offset, leading_byte_offset = stride_01, stride_11 + else: + stride_byte_offset, leading_byte_offset = stride_11, stride_01 + else: + if layout_type == LayoutType.SWIZZLE_128B_BASE32B: + raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K") + if not cute.size(layout.shape[0]) % 8 == 0: + raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.") + canonical_layout = cute.logical_divide(layout, (8, 2)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if stride_00 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_01 = canonical_layout.stride[0][1] + stride_byte_offset, leading_byte_offset = stride_01, stride_10 + + # ------------------------------------------------------------------ pack + desc = 0 + # leading_byte_offset_ [16:30) + desc |= (leading_byte_offset & 0x3FFF) << 16 + # stride_byte_offset_ [32:46) + desc |= (stride_byte_offset & 0x3FFF) << 32 + # version_ [46:48) + desc |= (VERSION & 0x3) << 46 + # base_offset_ [49:52) + desc |= (BASE_OFFSET & 0x7) << 49 + # lbo_mode_ [52:53) + desc |= (LBO_MODE & 0x1) << 52 + # layout_type_ [61:64) + desc |= (int(layout_type) & 0x7) << 61 + + return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width + + +def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: + # 14 bits, remove 4 LSB (bits 0-13 in desc) + return (start_addr.toint() & 0x3FFFF) >> 4 diff --git a/flash_attn/cute/named_barrier.py b/flash_attn/cute/named_barrier.py new file mode 100644 index 00000000000..777c44079a0 --- /dev/null +++ b/flash_attn/cute/named_barrier.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ +import enum + + +class NamedBarrierFwd(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() + WarpSchedulerWG1 = enum.auto() + WarpSchedulerWG2 = enum.auto() + WarpSchedulerWG3 = enum.auto() + PFull = enum.auto() + PEmpty = enum.auto() + + +class NamedBarrierBwd(enum.IntEnum): + Epilogue = enum.auto() + WarpSchedulerWG1 = enum.auto() + WarpSchedulerWG2 = enum.auto() + WarpSchedulerWG3 = enum.auto() + PdS = enum.auto() + dQFullWG0 = enum.auto() + dQFullWG1 = enum.auto() + dQEmptyWG0 = enum.auto() + dQEmptyWG1 = enum.auto() + + +class NamedBarrierBwdSm100(enum.IntEnum): + EpilogueWG1 = enum.auto() + EpilogueWG2 = enum.auto() + Compute = enum.auto() + dQaccReduce = enum.auto() diff --git a/flash_attn/cute/pack_gqa.py b/flash_attn/cute/pack_gqa.py new file mode 100644 index 00000000000..765e71307ad --- /dev/null +++ b/flash_attn/cute/pack_gqa.py @@ -0,0 +1,164 @@ +# Copyright (c) 2025, Tri Dao. + + +import cutlass +import cutlass.cute as cute + +import flash_attn.cute.utils as utils + + +class PackGQA: + def __init__( + self, + m_block_size: cutlass.Constexpr[int], + head_dim_padded: cutlass.Constexpr[int], + check_hdim_oob: cutlass.Constexpr[bool], + qhead_per_kvhead: cutlass.Constexpr[bool], + ): + self.m_block_size = m_block_size + self.head_dim_padded = head_dim_padded + self.check_hdim_oob = check_hdim_oob + self.qhead_per_kvhead = qhead_per_kvhead + + @cute.jit + def compute_ptr( + self, + tensor: cute.Tensor, + cRows: cute.Tensor, + tidx: cutlass.Int32, + block: cutlass.Int32, + threads_per_row: cutlass.Constexpr[int], + num_threads: cutlass.Constexpr[int], + ): + num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row) + tPrPtr = cute.make_fragment(num_ptr_per_thread, cutlass.Int64) + for i in cutlass.range_constexpr(num_ptr_per_thread): + row = i * num_threads + cRows[tidx % threads_per_row][0] + idx = block * self.m_block_size + row + m_idx = idx // self.qhead_per_kvhead + h_idx = idx - m_idx * self.qhead_per_kvhead + tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint() + return tPrPtr + + @cute.jit + def load_Q( + self, + mQ: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + sQ: cute.Tensor, # (m_block_size, head_dim_padded) + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tQsQ = gmem_thr_copy.partition_D(sQ) + tQcQ = gmem_thr_copy.partition_S(cQ) + t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) + tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1]) + tQcQ_row = tQcQ[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + q_ptr_i64 = utils.shuffle_sync( + tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row + ) + q_gmem_ptr = cute.make_ptr( + mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0QcQ[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0] + ): + mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tQsQ.shape[0][0]) + mQ_cur_copy = cute.tiled_divide(mQ_cur, 
(elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])): + ki = tQcQ[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + mQ_cur_copy[None, ki], + tQsQ[None, m, k], + pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + + @cute.jit + def store_LSE( + self, + mLSE: cute.Tensor, # (qhead_per_kvhead, seqlen_q) + tLSErLSE: cute.Tensor, # (m_block_size, head_dim_padded) + tiled_mma: cute.TiledMma, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + thr_mma = tiled_mma.get_slice(tidx) + caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + taccOcO = thr_mma.partition_C(caccO) + taccOcO_row = utils.make_acc_tensor_mn_view(taccOcO)[None, 0] + assert cute.size(tLSErLSE) == cute.size(taccOcO_row) + threads_per_row = tiled_mma.tv_layout_C.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + assert cute.size(tLSErLSE) <= threads_per_row + num_threads = tiled_mma.size + tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tLSErLSE)): + lse_ptr_i64 = utils.shuffle_sync( + tPrLSEPtr[m // threads_per_row], + m % threads_per_row, + width=threads_per_row, + ) + lse_gmem_ptr = cute.make_ptr( + mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4 + ) + row = block * self.m_block_size + taccOcO_row[m][0] + # Only the thread corresponding to column 0 writes out the lse to gmem + if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead: + mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,)) + mLSE_copy[0] = tLSErLSE[m] + + @cute.jit + def store_O( + self, + mO: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + tOrO: cute.Tensor, # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tOcO = gmem_thr_copy.partition_S(cO) + t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + tOcO_row = tOcO[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + o_ptr_i64 = utils.shuffle_sync( + tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row + ) + o_gmem_ptr = cute.make_ptr( + mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0OcO[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0] + ): + mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tOrO.shape[0][0]) + mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])): + ki = tOcO[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + tOrO[None, m, k], + mO_cur_copy[None, ki], + pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, + ) diff --git a/flash_attn/cute/paged_kv.py 
b/flash_attn/cute/paged_kv.py new file mode 100644 index 00000000000..e2d2d84433d --- /dev/null +++ b/flash_attn/cute/paged_kv.py @@ -0,0 +1,214 @@ +from typing import Type +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync +from cutlass import Int32, const_expr + +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import ParamsBase +from cutlass.cute import FastDivmodDivisor + +import math + + +@dataclass +class PagedKVManager(ParamsBase): + mPageTable: cute.Tensor + mK_paged: cute.Tensor + mV_paged: cute.Tensor + thread_idx: Int32 + + page_size_divmod: FastDivmodDivisor + seqlen_k: Int32 + leftpad_k: Int32 + n_block_size: Int32 + num_threads: cutlass.Constexpr[Int32] + head_dim_padded: cutlass.Constexpr[Int32] + head_dim_v_padded: cutlass.Constexpr[Int32] + + gmem_threads_per_row: cutlass.Constexpr[Int32] + page_entry_per_thread: Int32 + async_copy_elems: Int32 + + gmem_tiled_copy_KV: cute.TiledCopy + gmem_thr_copy_KV: cute.TiledCopy + tPrPage: cute.Tensor + tPrPageOffset: cute.Tensor + tKpK: cute.Tensor + tVpV: cute.Tensor + + @staticmethod + def create( + mPageTable: cute.Tensor, + mK_paged: cute.Tensor, + mV_paged: cute.Tensor, + page_size_divmod: FastDivmodDivisor, + bidb: Int32, + bidh: Int32, + thread_idx: Int32, + seqlen_k: Int32, + leftpad_k: Int32, + n_block_size: cutlass.Constexpr[Int32], + head_dim_padded: cutlass.Constexpr[Int32], + head_dim_v_padded: cutlass.Constexpr[Int32], + num_threads: cutlass.Constexpr[Int32], + dtype: Type[cutlass.Numeric], + ): + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // dtype.width + dtype_bytes = dtype.width // 8 + gmem_k_block_size = math.gcd( + head_dim_padded, + head_dim_v_padded, + 128 // dtype_bytes, + ) + assert gmem_k_block_size % async_copy_elems == 0 + gmem_threads_per_row = gmem_k_block_size // async_copy_elems + assert cute.arch.WARP_SIZE % gmem_threads_per_row == 0 + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + dtype, + num_bits_per_copy=universal_copy_bits, + ) + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, async_copy_elems)) + gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout) + gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx) + page_entry_per_thread = n_block_size // num_threads + + tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32) + tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32) + + mPageTable = mPageTable[bidb, None] + mK_paged = mK_paged[None, None, bidh, None] + mV_paged = mV_paged[None, None, bidh, None] + + cK = cute.make_identity_tensor((n_block_size, head_dim_padded)) + tKcK = gmem_thr_copy_KV.partition_S(cK) + tKpK = utils.predicate_k(tKcK, limit=mK_paged.shape[1]) + + if const_expr(head_dim_padded == head_dim_v_padded): + tVpV = tKpK + else: + cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded)) + tVcV = gmem_thr_copy_KV.partition_S(cV) + tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0]) + + return PagedKVManager( + mPageTable, + mK_paged, + mV_paged, + thread_idx, + page_size_divmod, + seqlen_k, + leftpad_k, + n_block_size, + num_threads, + head_dim_padded, + head_dim_v_padded, + gmem_threads_per_row, + page_entry_per_thread, + async_copy_elems, + gmem_tiled_copy_KV, + gmem_thr_copy_KV, + tPrPage, + tPrPageOffset, + tKpK, + tVpV, + ) + 
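+    # How a logical row of K/V maps onto paged storage (this is what
+    # load_page_table and compute_X_ptr below implement):
+    #   page_idx, page_offset = divmod(row_idx + leftpad_k, page_size)
+    #   page = mPageTable[page_idx]
+    # The per-row gmem pointer is then taken from mK_paged at
+    # (page_offset, 0, page) and from mV_paged at (0, page_offset, page),
+    # i.e. V uses a transposed index order.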
+ @cute.jit + def load_page_table(self, n_block: Int32): + for i in cutlass.range(self.page_entry_per_thread, unroll=1): + row = ( + i * self.num_threads + + (self.thread_idx % self.gmem_threads_per_row) + * (self.num_threads // self.gmem_threads_per_row) + + (self.thread_idx // self.gmem_threads_per_row) + ) + row_idx = n_block * self.n_block_size + row + + page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod) + + is_valid = ( + (i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size + ) and row_idx < self.seqlen_k + page = self.mPageTable[page_idx] if is_valid else 0 + + self.tPrPage[i] = page + self.tPrPageOffset[i] = page_offset + + @cute.jit + def compute_X_ptr(self, K_or_V: str): + tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64) + for i in cutlass.range(self.page_entry_per_thread, unroll=1): + page = self.tPrPage[i] + page_offset = self.tPrPageOffset[i] + if const_expr(K_or_V == "K"): + tPrXPtr[i] = utils.elem_pointer(self.mK_paged, (page_offset, 0, page)).toint() + else: + tPrXPtr[i] = utils.elem_pointer(self.mV_paged, (0, page_offset, page)).toint() + return tPrXPtr + + @cute.jit + def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): + assert K_or_V in ("K", "V") + + tPrXPtr = self.compute_X_ptr(K_or_V) + + # Finesse sX layout to be (M, N). + sX_pi = cute.make_tensor( + sX.iterator, + cute.make_layout( + (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])), + stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])), + ), + ) + + if const_expr(K_or_V == "V"): + # Need to transpose V + sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0])) + + head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded + cX = cute.make_identity_tensor((self.n_block_size, head_dim)) + tXsX = self.gmem_thr_copy_KV.partition_D(sX_pi) + tXcX = self.gmem_thr_copy_KV.partition_S(cX) + tXc0X = self.gmem_thr_copy_KV.get_slice(0).partition_S(cX) + + seqlenk_row_limit = ( + self.seqlen_k - n_block * self.n_block_size - tXcX[0][0] if n_block >= 0 else 0 + ) + for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])): + row_valid = tXc0X[0, m, 0][0] < seqlenk_row_limit + should_load = cute.make_fragment_like(tXsX[(0, None), m, 0], cute.Boolean) + should_load.fill(row_valid) + + x_ptr_i64 = utils.shuffle_sync( + tPrXPtr[m // self.gmem_threads_per_row], + m % self.gmem_threads_per_row, + width=self.gmem_threads_per_row, + ) + x_gmem_ptr = cute.make_ptr( + self.mK_paged.element_type, x_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + mX_paged_cur = cute.make_tensor(x_gmem_ptr, cute.make_layout((head_dim,))) + mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,)) + + for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])): + ki = tXcX[0, 0, k][1] // self.async_copy_elems + mX_paged_cur_copy_ki = mX_paged_cur_copy[None, ki] + tXsX_k = tXsX[None, m, k] + mX_paged_cur_copy_ki = cute.make_tensor( + mX_paged_cur_copy_ki.iterator, tXsX_k.layout + ) + cute.copy( + self.gmem_tiled_copy_KV, + mX_paged_cur_copy_ki, + tXsX_k, + pred=should_load, + ) diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py new file mode 100644 index 00000000000..54981bca127 --- /dev/null +++ b/flash_attn/cute/pipeline.py @@ -0,0 +1,272 @@ +# Copyright (c) 2025, Tri Dao. 
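+#
+# Thin overrides of cutlass.pipeline used by these kernels:
+#   * pipeline_init_wait syncs the cluster with cluster_arrive_relaxed,
+#   * PipelineStateSimple packs the stage index and phase bit into a single
+#     Int32 (divmod by the stage count recovers index and phase),
+#   * PipelineTmaAsync / PipelineTmaUmma accept an extra_tx_count in
+#     producer_acquire so extra transaction bytes can be expected on the
+#     same mbarrier.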
+ +# import math +from typing import Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Boolean, Int32, const_expr +from cutlass.cutlass_dsl import if_generate +from cutlass.pipeline import PipelineAsync, PipelineState, Agent, CooperativeGroup +from cutlass.pipeline import PipelineUserType, PipelineOp +from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg +from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg + + +# We deviate from cute-dsl implementation to use cute.arch.cluster_arrive_relaxed +def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None): + """ + Fences the mbarrier init and syncs the threadblock or cluster + """ + cute.arch.mbarrier_init_fence() + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # If not using clusters, sync the threadblock + _sync(Agent.ThreadBlock) + else: + # If using clusters, sync the cluster + _sync(Agent.ThreadBlockCluster) + + +def _sync(group: Agent): + """ + Syncs all threads within an agent. + """ + if group is Agent.Thread: + raise NotImplementedError("Error: Not supported.") + elif group is Agent.ThreadBlock: + cute.arch.sync_threads() + elif group is Agent.ThreadBlockCluster: + cute.arch.cluster_arrive_relaxed() + cute.arch.cluster_wait() + else: + assert False, ( + "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead." + ) + + +class PipelineStateSimple: + """ + Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer. + Use a single Int32 to store both the index and phase bit, then we use divmod to get the + index and phase. If stages is a power of 2, divmod turns into bit twiddling. + """ + + def __init__(self, stages: int, phase_index: Int32): + # assert stages < 2**16 + # self._log_stages = int(math.log2(stages)) + # assert 1 << self._log_stages == stages, "Number of stages must be a power of 2." + self._stages = stages + self._phase_index = phase_index + + def clone(self) -> "PipelineStateSimple": + return PipelineStateSimple(self.stages, self._phase_index) + + @property + def stages(self) -> int: + # return 1 << self._log_stages + return self._stages + + @property + def index(self) -> Int32: + # return self._phase_index & 0xFFFF + # return self._phase_index & ((1 << self._log_stages) - 1) + if const_expr(self._stages == 1): + return Int32(0) + else: + return self._phase_index % self._stages + + @property + def phase(self) -> Int32: + # return self._phase_index >> 16 + # PTX docs say that the phase parity needs to be 0 or 1, so by right we need to + # take modulo 2. But in practice just passing the phase in without modulo works fine. 
+ # return (self._phase_index >> self._log_stages) % 2 + # return self._phase_index >> self._log_stages + if const_expr(self._stages == 1): + return self._phase_index + else: + return self._phase_index // self._stages + + def advance(self): + if const_expr(self._stages == 1): + self._phase_index ^= 1 + else: + self._phase_index += 1 + + # def then_body(phase_index): + # # XOR the phase bit and set the index to 0 + # return (phase_index & 0xFFFF0000) ^ (1 << 16) + + # def else_body(phase_index): + # return phase_index + + # self._phase_index = if_generate( + # (self._phase_index & 0xFFFF) == self.stages, + # then_body, + # else_body, + # [self._phase_index], + # [Int32], + # ) + + def __extract_mlir_values__(self): + phase_index = self._phase_index + return [phase_index.ir_value()] + + def __new_from_mlir_values__(self, values): + return PipelineStateSimple(self.stages, Int32(values[0])) + + +def make_pipeline_state(type: PipelineUserType, stages: int): + """ + Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1. + """ + if type is PipelineUserType.Producer: + # return PipelineStateSimple(stages, Int32(1 << 16)) + return PipelineStateSimple(stages, Int32(stages)) + elif type is PipelineUserType.Consumer: + return PipelineStateSimple(stages, Int32(0)) + else: + assert False, "Error: invalid PipelineUserType specified for make_pipeline_state." + + +@dataclass(frozen=True) +class PipelineTmaAsync(PipelineTmaAsyncOg): + """ + Override producer_acquire to take in extra_tx_count parameter. + """ + + @staticmethod + def create(*args, **kwargs): + obj = PipelineTmaAsyncOg.create(*args, **kwargs) + # Can't assign to __class__ directly since the dataclass is frozen + # obj.__class__ = PipelineTmaAsync + object.__setattr__(obj, "__class__", PipelineTmaAsync) + return obj + + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + if const_expr(extra_tx_count == 0): + self.sync_object_full.arrive(state.index, self.producer_mask) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + self.sync_object_full.arrive_and_expect_tx(state.index, tx_count) + + +@dataclass(frozen=True) +class PipelineTmaUmma(PipelineTmaUmmaOg): + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + mcast_mode_mn: tuple[int, int] = (1, 1), + init_wait: cutlass.Constexpr[bool] = True, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma. 
+ :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + :param mcast_mode_mn: Tuple of two integers, specifying whether mcast is enabled for the m and n modes. At least one of the two integers must be 1. + :type mcast_mode_mn: tuple[int, int] + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.TCGen05Mma + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # No mcast mask if not using clusters + producer_mask = None + # All threadblocks are leaders if not using clusters + is_leader_cta = True + else: + producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask( + cta_layout_vmnk, mcast_mode_mn + ) + is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) + + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + consumer_mask = producer_mask + + if const_expr(init_wait): + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaUmma( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + is_leader_cta, + cta_group, + ) + + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + extra_tx_count: int = 0, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. 
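+        When extra_tx_count is nonzero, the leader CTA calls
+        arrive_and_expect_tx with self.sync_object_full.tx_count +
+        extra_tx_count instead of a plain arrive, so additional transaction
+        bytes can be expected on this stage's barrier.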
+ """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + if const_expr(extra_tx_count == 0): + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(state.index, self.producer_mask), + ) + else: + tx_count = self.sync_object_full.tx_count + extra_tx_count + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive_and_expect_tx(state.index, tx_count), + ) diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml new file mode 100644 index 00000000000..1503556c122 --- /dev/null +++ b/flash_attn/cute/pyproject.toml @@ -0,0 +1,56 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "flash-attn-cute" +version = "0.1.0" +description = "Flash Attention CUTE (CUDA Template Engine) implementation" +readme = "README.md" +requires-python = ">=3.10" +license = {text = "BSD 3-Clause License"} +authors = [ + {name = "Tri Dao"}, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] + +dependencies = [ + "nvidia-cutlass-dsl>=4.3.5,<4.4.0", + "torch", + "einops", + "typing_extensions", + "apache-tvm-ffi>=0.1.5,<0.2", + "torch-c-dlpack-ext", + "quack-kernels==0.2.4", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "ruff", +] + +[project.urls] +Homepage = "https://github.com/Dao-AILab/flash-attention" +Repository = "https://github.com/Dao-AILab/flash-attention" + +[tool.setuptools] +packages = ["flash_attn.cute"] +package-dir = {"flash_attn.cute" = "."} + +[tool.ruff] +line-length = 100 + +[tool.ruff.lint] +ignore = [ + "E731", # do not assign a lambda expression, use a def + "E741", # Do not use variables named 'I', 'O', or 'l' + "F841", # local variable is assigned to but never used +] diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py new file mode 100644 index 00000000000..6d8c6feb279 --- /dev/null +++ b/flash_attn/cute/seqlen_info.py @@ -0,0 +1,138 @@ +from typing import Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, const_expr + +""" +This consolidates all the info related to sequence length. This is so that we can do all +the gmem reads once at the beginning of each tile, rather than having to repeat these reads +to compute various things like n_block_min, n_block_max, etc. 
+""" + + +@dataclass(frozen=True) +class SeqlenInfo: + offset: cutlass.Int32 + seqlen: cutlass.Int32 + + @staticmethod + def create( + batch_idx: cutlass.Int32, + seqlen_static: cutlass.Int32, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + ): + offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] + if const_expr(seqused is not None): + seqlen = seqused[batch_idx] + elif const_expr(cu_seqlens is not None): + seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] + else: + seqlen = seqlen_static + return SeqlenInfo(offset, seqlen) + + +@dataclass(frozen=True) +class SeqlenInfoQK: + offset_q: cutlass.Int32 + offset_k: cutlass.Int32 + padded_offset_q: cutlass.Int32 + padded_offset_k: cutlass.Int32 + seqlen_q: cutlass.Int32 + seqlen_k: cutlass.Int32 + has_cu_seqlens_q: cutlass.Constexpr[bool] + has_cu_seqlens_k: cutlass.Constexpr[bool] + has_seqused_q: cutlass.Constexpr[bool] + has_seqused_k: cutlass.Constexpr[bool] + + @staticmethod + def create( + batch_idx: cutlass.Int32, + seqlen_q_static: cutlass.Int32, + seqlen_k_static: cutlass.Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + tile_m: cutlass.Constexpr[cutlass.Int32] = 128, + tile_n: cutlass.Constexpr[cutlass.Int32] = 128, + ): + offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] + offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] + padded_offset_q = ( + 0 + if const_expr(mCuSeqlensQ is None) + else (offset_q + batch_idx * tile_m) // tile_m * tile_m + ) + padded_offset_k = ( + 0 + if const_expr(mCuSeqlensK is None) + else (offset_k + batch_idx * tile_n) // tile_n * tile_n + ) + if const_expr(mSeqUsedQ is not None): + seqlen_q = mSeqUsedQ[batch_idx] + else: + seqlen_q = ( + seqlen_q_static + if const_expr(mCuSeqlensQ is None) + else mCuSeqlensQ[batch_idx + 1] - offset_q + ) + if const_expr(mSeqUsedK is not None): + seqlen_k = mSeqUsedK[batch_idx] + else: + seqlen_k = ( + seqlen_k_static + if const_expr(mCuSeqlensK is None) + else mCuSeqlensK[batch_idx + 1] - offset_k + ) + has_cu_seqlens_q: int = mCuSeqlensQ is not None + has_cu_seqlens_k: int = mCuSeqlensK is not None + has_seqused_q: int = mSeqUsedQ is not None + has_seqused_k: int = mSeqUsedK is not None + return SeqlenInfoQK( + offset_q, + offset_k, + padded_offset_q, + padded_offset_k, + seqlen_q, + seqlen_k, + has_cu_seqlens_q, + has_cu_seqlens_k, + has_seqused_q, + has_seqused_k, + ) + + def offset_batch_Q( + self, + mQ: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + ) -> cute.Tensor: + """Seqlen must be the first dimension of mQ""" + if const_expr(not self.has_cu_seqlens_q): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) + return mQ[idx] + else: + offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q + offset = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, offset_q) + idx = (offset,) + (0,) * (cute.rank(mQ) - 1) + return cute.domain_offset(idx, mQ) + + def offset_batch_K( + self, + mK: cute.Tensor, + batch_idx: Int32, + dim: int, + padded: cutlass.Constexpr[bool] = False, + ) -> cute.Tensor: + """Seqlen must be the first dimension of mK""" + if const_expr(not self.has_cu_seqlens_k): + idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) + return mK[idx] + else: + offset_k = self.offset_k if const_expr(not 
padded) else self.padded_offset_k + idx = (offset_k,) + (0,) * (cute.rank(mK) - 1) + return cute.domain_offset(idx, mK) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py new file mode 100644 index 00000000000..f0646c22714 --- /dev/null +++ b/flash_attn/cute/softmax.py @@ -0,0 +1,582 @@ +# Copyright (c) 2025, Tri Dao. + +import math +import operator +from typing import Tuple +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Float32 + +import flash_attn.cute.utils as utils +from flash_attn.cute.cute_dsl_utils import ParamsBase +from flash_attn.cute.seqlen_info import SeqlenInfoQK + + +@dataclass +class Softmax(ParamsBase): + scale_log2: Float32 + num_rows: cutlass.Constexpr[int] + row_max: cute.Tensor + row_sum: cute.Tensor + arch: cutlass.Constexpr[int] = 80 + softmax_scale: Float32 | None = None + + @staticmethod + def create( + scale_log2: Float32, + num_rows: cutlass.Constexpr[int], + arch: cutlass.Constexpr[int] = 80, + softmax_scale: Float32 | None = None, + ): + row_max = cute.make_rmem_tensor(num_rows, Float32) + row_sum = cute.make_rmem_tensor(num_rows, Float32) + return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale) + + def reset(self) -> None: + self.row_max.fill(-Float32.inf) + self.row_sum.fill(0.0) + + def _compute_row_max( + self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch) + + def _compute_row_sum( + self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch) + + @cute.jit + def online_softmax( + self, + acc_S: cute.Tensor, + is_first: cutlass.Constexpr[bool] = False, + check_inf: cutlass.Constexpr[bool] = True, + ) -> cute.Tensor: + """Apply online softmax and return the row_scale to rescale O. + + :param acc_S: acc_S tensor + :type acc_S: cute.Tensor + :param is_first: is first n_block + :type is_first: cutlass.Constexpr + """ + # Change acc_S to M,N layout view. 
+ acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) + row_scale = cute.make_fragment_like(self.row_max, Float32) + + row_max = self.row_max + row_sum = self.row_sum + scale_log2 = self.scale_log2 + arch = self.arch + + # Each iteration processes one row of acc_S + for r in cutlass.range(cute.size(row_max), unroll_full=True): + acc_S_row = acc_S_mn[r, None].load() # (n_block_size) + + row_max_cur = utils.fmax_reduce( + acc_S_row, + init_val=row_max[r] if cutlass.const_expr(not is_first) else None, + arch=arch, + ) + + row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4) + # Update row_max before changing row_max_cur to safe value for -inf + row_max_prev = row_max[r] + row_max[r] = row_max_cur + + if cutlass.const_expr(check_inf): + row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur + + if cutlass.const_expr(is_first): + row_max_cur_scaled = row_max_cur * scale_log2 + acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled) + + acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch) + row_scale[r] = 1.0 + else: + row_max_cur_scaled = row_max_cur * scale_log2 + acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled) + # row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled) + row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * scale_log2) + + acc_S_row_sum = utils.fadd_reduce( + acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch + ) + + row_sum[r] = acc_S_row_sum + acc_S_mn[r, None].store(acc_S_row_exp) + + return row_scale + + @cute.jit + def finalize( + self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None + ) -> cute.Tensor: + """Finalize the online softmax by computing the scale and logsumexp.""" + if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)): + assert cute.size(sink_val) == cute.size(self.row_sum) + row_sum = self.row_sum + row_max = self.row_max + scale_log2 = self.scale_log2 + + # quad reduction for row_sum as we didn't do it during each iteration of online softmax + row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4)) + row_scale = cute.make_fragment_like(row_max, Float32) + + for r in cutlass.range(cute.size(row_sum), unroll_full=True): + if cutlass.const_expr(sink_val is not None): + sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] + LOG2_E = math.log2(math.e) + row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - row_max[r] * scale_log2) + + # if row_sum is zero or nan, set acc_O_mn_row to 1.0 + acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r] + row_scale[r] = ( + cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) + ) * final_scale + row_sum_cur = row_sum[r] + LN2 = math.log(2.0) + row_sum[r] = ( + (row_max[r] * scale_log2 + utils.log2f(row_sum_cur)) * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + return row_scale + + @cute.jit + def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: + """Scale each row of acc_O by the given scale tensor. 
+ :param acc_O: input tensor + :type acc_O: cute.Tensor + :param row_scale: row_scale tensor + :type row_scale: cute.Tensor + """ + acc_O_mn = utils.make_acc_tensor_mn_view(acc_O) + assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0]) + for r in cutlass.range(cute.size(row_scale), unroll_full=True): + acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) + + +@dataclass +class SoftmaxSm100(Softmax): + rescale_threshold: cutlass.Constexpr[float] = 0.0 + + @staticmethod + def create( + scale_log2: Float32, + rescale_threshold: cutlass.Constexpr[float] = 0.0, + softmax_scale: Float32 | None = None, + ): + num_rows = 1 + arch = 100 + row_max = cute.make_rmem_tensor(num_rows, Float32) + row_sum = cute.make_rmem_tensor(num_rows, Float32) + return SoftmaxSm100( + scale_log2, + num_rows, + row_max, + row_sum, + arch, + softmax_scale, + rescale_threshold=rescale_threshold, + ) + + @cute.jit + def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = utils.exp2f(acc_scale_) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 + self.row_max[0] = row_max_new + return row_max_safe, acc_scale + + def update_row_sum( + self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False + ) -> None: + init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None + # self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale) + self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val) + # tmp = self._compute_row_sum(acc_S_row_exp) + # self.row_sum[0] = self.row_sum[0] * row_scale + tmp + + @cute.jit + def scale_subtract_rowmax( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + row_max_scaled = row_max * self.scale_log2 + for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): + acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (-row_max_scaled, -row_max_scaled), + ) + + @cute.jit + def apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + acc_S_row_converted: cute.Tensor, + e2e: cutlass.Constexpr[bool] = False, + e2e_freq: cutlass.Constexpr[int] = 16, + e2e_res: cutlass.Constexpr[int] = 4, + e2e_frg_limit: cutlass.Constexpr[int] = 1, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + # acc_S_row_frg[k, j] = 
utils.exp2f(acc_S_row_frg[k, j]) + # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j]) + if cutlass.const_expr(not e2e): + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + else: + if cutlass.const_expr( + k % e2e_freq < e2e_freq - e2e_res or j >= frg_cnt - e2e_frg_limit + ): + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + else: + # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]) + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2( + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] + ) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + @cute.jit + def scale_apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + + # for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + # acc_S_row[i], acc_S_row[i + 1] = utils.fma_packed_f32x2( + # (acc_S_row[i], acc_S_row[i + 1]), + # (self.scale_log2, self.scale_log2), + # (minus_row_max_scaled, minus_row_max_scaled), + # ) + # acc_S_row[i] = cute.arch.exp2(acc_S_row[i]) + # acc_S_row[i + 1] = cute.arch.exp2(acc_S_row[i + 1]) + + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( + # utils.fma_packed_f32x2( + # (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + # (self.scale_log2, self.scale_log2), + # (minus_row_max_scaled, minus_row_max_scaled), + # ) + # ) + # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j]) + # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j]) + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + +@cute.jit +def floor_if_packed( + q_idx, + qhead_per_kvhead: cutlass.Constexpr[int], +) -> cute.Tensor: + """Convert q_idx to packed format for Pack-GQA.""" + if cutlass.const_expr(qhead_per_kvhead == 1): + return q_idx + return q_idx // qhead_per_kvhead + + +@cute.jit +def apply_score_mod_inner( + score_tensor, + index_tensor, + score_mod: cutlass.Constexpr, + batch_idx, + head_idx, + softmax_scale, + vec_size: cutlass.Constexpr, + qk_acc_dtype: cutlass.Constexpr, + aux_tensors, + fastdiv_mods, + seqlen_info: SeqlenInfoQK, + constant_q_idx: cutlass.Constexpr, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, + transpose_indices: cutlass.Constexpr[bool] = False, +): + """Shared implementation for applying score modification. 
+ + Args: + score_tensor: The scores to modify (acc_S for flash_fwd, tSrS_t2r for sm100) + index_tensor: Index positions (tScS for flash_fwd, tScS_t2r for sm100) + score_mod: The score modification function to apply + batch_idx: Batch index + head_idx: Head index + softmax_scale: Scale to apply + vec_size: Vector size for processing elements + qk_acc_dtype: Data type for accumulator + aux_tensors: Optional aux_tensors for FlexAttention + fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping + seqlen_info: Sequence length info + constant_q_idx: If provided, use this constant for all q_idx values + If None, compute q_idx per-element + qhead_per_kvhead_packgqa: Pack-GQA replication factor. Divide q_idx by this + when greater than 1 so score mods see logical heads. + transpose_indices: If True, swap q_idx/kv_idx in index_tensor (for bwd kernel where S is transposed) + """ + # Index positions in the index_tensor tuple + # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx + # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx + if cutlass.const_expr(transpose_indices): + q_idx_pos = cutlass.const_expr(1) + kv_idx_pos = cutlass.const_expr(0) + else: + q_idx_pos = cutlass.const_expr(0) + kv_idx_pos = cutlass.const_expr(1) + + n_vals = cutlass.const_expr(cute.size(score_tensor.shape)) + score_vec = cute.make_rmem_tensor(vec_size, qk_acc_dtype) + kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) + + # SSA values for batch (constant across all elements) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,)) + + # Handle q_idx based on whether it's constant + q_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) + + # For Pack-GQA with non-constant q_idx, we need per-element head indices + # since a thread my process multiple query head indices + if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): + head_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) + + for i in cutlass.range(0, n_vals, vec_size, unroll_full=True): + for j in cutlass.range(vec_size, unroll_full=True): + score_vec[j] = score_tensor[i + j] * softmax_scale + + # Extract head offset from packed q_idx for Pack-GQA + if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): + q_idx_packed = index_tensor[i + j][q_idx_pos] + # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead) + q_idx_logical = q_idx_packed // qhead_per_kvhead + head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead + head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset + + # If we will do loads we mod, in order to not read OOB + if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None): + if cutlass.const_expr(constant_q_idx is None): + seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods + q_idx_floored = floor_if_packed( + index_tensor[i + j][q_idx_pos], qhead_per_kvhead + ) + _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod) + q_idx_vec[j] = q_idx_wrapped + else: + _, seqlen_k_divmod = fastdiv_mods + + _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod) + kv_idx_vec[j] = kv_idx_wrapped + else: + # No bounds checking - direct indexing + if constant_q_idx is None: + q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead) + kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos] + + # Convert to SSA for score_mod call + score_ssa = score_vec.load() + kv_idx_ssa = 
kv_idx_vec.load() + if cutlass.const_expr(constant_q_idx is None): + q_idx_ssa = q_idx_vec.load() + else: + # NB we do not apply Pack-GQA division here, as constant_q_idx is assumed to already be logical + q_idx_const = constant_q_idx + q_idx_ssa = utils.scalar_to_ssa(q_idx_const, cutlass.Int32).broadcast_to((vec_size,)) + + # Compute head_idx_ssa: per-element for Pack-GQA with non-constant q_idx, constant otherwise + if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): + head_idx_ssa = head_idx_vec.load() + else: + head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,)) + + aux_args = [] + if cutlass.const_expr(aux_tensors is not None): + aux_args = aux_tensors + + post_mod_scores = score_mod( + score_ssa, + batch_idx_ssa, + head_idx_ssa, + q_idx=q_idx_ssa, + kv_idx=kv_idx_ssa, + seqlen_info=seqlen_info, + aux_tensors=aux_args, + ) + + # Write back modified scores + score_vec.store(post_mod_scores) + for j in cutlass.range(vec_size, unroll_full=True): + score_tensor[i + j] = score_vec[j] + + +@cute.jit +def apply_score_mod_bwd_inner( + grad_tensor, + score_tensor, + index_tensor, + score_mod_bwd: cutlass.Constexpr, + batch_idx, + head_idx, + softmax_scale, + vec_size: cutlass.Constexpr, + qk_acc_dtype: cutlass.Constexpr, + aux_tensors, + fastdiv_mods, + seqlen_info, + constant_q_idx: cutlass.Constexpr, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, + transpose_indices: cutlass.Constexpr[bool] = False, +): + """Apply backward score modification (joint graph). + + Args: + grad_tensor: in/out: dlogits rewritten in-place with d(scaled_scores) + score_tensor: pre-mod scores (unscaled QK tile), scaled by softmax_scale internally + index_tensor: Index positions (same as forward) + score_mod_bwd: The backward score modification function (joint graph) + batch_idx: Batch index + head_idx: Head index + softmax_scale: Scale to apply to score_tensor + vec_size: Vector size for processing elements + qk_acc_dtype: Data type for accumulator + aux_tensors: Optional aux_tensors for FlexAttention + fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping + seqlen_info: Sequence length info + constant_q_idx: If provided, use this constant for all q_idx values + qhead_per_kvhead: Pack-GQA replication factor + transpose_indices: If True, swap q_idx/kv_idx in index_tensor + """ + # Index positions in the index_tensor tuple + # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx + # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx + if cutlass.const_expr(transpose_indices): + q_idx_pos = cutlass.const_expr(1) + kv_idx_pos = cutlass.const_expr(0) + else: + q_idx_pos = cutlass.const_expr(0) + kv_idx_pos = cutlass.const_expr(1) + n_vals = cutlass.const_expr(cute.size(grad_tensor.shape)) + grad_vec = cute.make_fragment(vec_size, qk_acc_dtype) + score_vec = cute.make_fragment(vec_size, qk_acc_dtype) + kv_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,)) + q_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + + # For Pack-GQA with non-constant q_idx, we need per-element head indices + if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): + head_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) + + for i in cutlass.range(0, n_vals, vec_size, unroll_full=True): + for j in cutlass.range(vec_size, unroll_full=True): + grad_vec[j] = grad_tensor[i + j] + # Scale score so joint graph sees same 
value as forward score_mod + score_vec[j] = score_tensor[i + j] * softmax_scale + + if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): + q_idx_packed = index_tensor[i + j][q_idx_pos] + q_idx_logical = q_idx_packed // qhead_per_kvhead + head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead + head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset + + if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None): + if cutlass.const_expr(constant_q_idx is None): + seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods + q_idx_floored = floor_if_packed( + index_tensor[i + j][q_idx_pos], qhead_per_kvhead + ) + _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod) + q_idx_vec[j] = q_idx_wrapped + else: + _, seqlen_k_divmod = fastdiv_mods + + _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod) + kv_idx_vec[j] = kv_idx_wrapped + else: + # No bounds checking - direct indexing + if constant_q_idx is None: + q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead) + kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos] + + grad_ssa = grad_vec.load() + score_ssa = score_vec.load() + kv_idx_ssa = kv_idx_vec.load() + + if cutlass.const_expr(constant_q_idx is None): + q_idx_ssa = q_idx_vec.load() + else: + q_idx_ssa = utils.scalar_to_ssa(constant_q_idx, cutlass.Int32).broadcast_to((vec_size,)) + + if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): + head_idx_ssa = head_idx_vec.load() + else: + head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,)) + + aux_args = [] + if cutlass.const_expr(aux_tensors is not None): + aux_args = aux_tensors + + grad_out_ssa = score_mod_bwd( + grad_ssa, + score_ssa, + batch_idx_ssa, + head_idx_ssa, + q_idx=q_idx_ssa, + kv_idx=kv_idx_ssa, + seqlen_info=seqlen_info, + aux_tensors=aux_args, + ) + + grad_vec.store(grad_out_ssa) + for j in cutlass.range(vec_size, unroll_full=True): + grad_tensor[i + j] = grad_vec[j] diff --git a/flash_attn/cute/testing.py b/flash_attn/cute/testing.py new file mode 100644 index 00000000000..2897e64fc3d --- /dev/null +++ b/flash_attn/cute/testing.py @@ -0,0 +1,423 @@ +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + return torch.gather( + rearrange(input, "b ... -> b (...)"), + 0, + repeat(indices, "z -> z d", d=second_dim), + ).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... 
-> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, + dtype=grad_output.dtype, + ) + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros( + first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype + ) + output[indices] = values + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + grad_values = grad_output[indices] + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +def unpad_input(hidden_states, attention_mask, unused_mask=None): + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... -> b s ...", b=batch) + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint( + max(0 if zero_lengths else 1, max_seqlen - 20), + max_seqlen + 1, + (batch_size, 1), + device=device, + ) + else: + lengths = torch.randint( + max(0 if zero_lengths else 1, max_seqlen // 3), + max_seqlen + 1, + (batch_size, 1), + device=device, + ) + + if zero_lengths: + for i in range(batch_size): + if i % 5 == 0: + lengths[i] = 0 + lengths[-1] = 0 + padding_mask = ( + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + ) + return padding_mask + + +def generate_qkv( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + qv=None, + kvpacked=False, + qkvpacked=False, + query_unused_mask=None, + key_unused_mask=None, +): + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + d_v = v.shape[-1] + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d_v) + if query_unused_mask is not None or key_unused_mask is not None: + assert not kvpacked + assert not qkvpacked + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q, query_padding_mask, query_unused_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = rearrange(qv, "b s ... 
-> (b s) ...")[indices_q] if qv is not None else None + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + ) + seqused_q = None + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( + k, key_padding_mask, key_unused_mask + ) + v_unpad, *_ = unpad_input(v, key_padding_mask, key_unused_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + ) + seqused_k = None + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + else: + dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + qv_unpad.detach() if qv is not None else None, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + qv.detach() if qv is not None else None, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(None, None), + sink_token_length=0, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else 
rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] is None: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + if window_size[1] is None: + local_mask_left = col_idx > sk + else: + local_mask_left = col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk) + return torch.logical_or( + local_mask_left, + torch.logical_and( + col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length + ), + ) + + +def construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk + return torch.logical_or( + col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=(None, None), + attention_chunk=0, + sink_token_length=0, + learnable_sink: Optional[torch.Tensor] = None, + softcap=0.0, + upcast=True, + reorder_ops=False, + intermediate_dtype=None, +): + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None + if q_descale is not None: + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) + q = (q.float() * q_descale).to(q.dtype) + qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None + if k_descale is not None: + k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) + if v_descale is not None: + v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype) + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) + if softcap > 0: + scores = torch.tanh(scores / softcap) * softcap + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + local_mask = None + if window_size[0] is not None or 
window_size[1] is not None: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + sink_token_length, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + if attention_chunk > 0: + chunk_mask = construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + local_mask = ( + torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask + ) + if local_mask is not None: + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + if learnable_sink is None: + attention = torch.softmax(scores, dim=-1).to(v.dtype) + else: + scores_fp32 = scores.to(torch.float32) + logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) + learnable_sink = rearrange(learnable_sink, "h -> h 1 1") + logits_or_sinks_max = torch.maximum(learnable_sink, logits_max) + unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp( + learnable_sink - logits_or_sinks_max + ) + attention = (unnormalized_scores / normalizer).to(v.dtype) + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + if key_padding_mask is not None: + attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) + if local_mask is not None: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + if intermediate_dtype is not None: + attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py new file mode 100644 index 00000000000..36a5c6b75ec --- /dev/null +++ b/flash_attn/cute/tile_scheduler.py @@ -0,0 +1,719 @@ +# Copyright (c) 2025, Tri Dao. 
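+# Rough usage contract shared by the schedulers in this file (a sketch, not a formal interface):
+# to_underlying_arguments() runs on the host and folds TileSchedulerArguments into a Params struct,
+# create() builds the scheduler inside the kernel from the block index, and the kernel then
+# alternates get_current_work() / advance_to_next_work() until the returned WorkTileInfo is no
+# longer valid.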
+ +from typing import Optional, Tuple +from dataclasses import dataclass, fields + +try: + from typing import override +except ImportError: # Python < 3.12 + from typing_extensions import override + +import cutlass +from cutlass._mlir import ir +import cutlass.cute as cute +from cutlass import Int32, const_expr + +import flash_attn.cute.utils as utils +from flash_attn.cute.fast_math import clz +from cutlass.cute import FastDivmodDivisor + + +class WorkTileInfo(cutlass.utils.WorkTileInfo): + """Altered WorkTileInfo which includes four axes: (block, head, batch, split)""" + + @override + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": + assert len(values) == 5 + new_tile_idx = cutlass.new_from_mlir_values(self._tile_idx, values[:-1]) + new_is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, cutlass.Constexpr)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, cutlass.Constexpr) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class TileSchedulerArguments(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + seqlen_k: Int32 + headdim: Int32 + headdim_v: Int32 + total_q: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + element_size: cutlass.Constexpr[int] = 2 + is_persistent: cutlass.Constexpr[bool] = False + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False + + +class SingleTileScheduler: + @dataclass + class Params(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + num_splits: Int32 + num_splits_divmod: FastDivmodDivisor + is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileScheduler.Params": + return SingleTileScheduler.Params( + args.num_block, + args.num_head, + args.num_batch, + args.num_splits, + FastDivmodDivisor(args.num_splits), + args.is_split_kv, + args.cluster_shape_mn, + ) + + def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): + self.params = params + self._blk_coord = blk_coord + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> 
Params: + return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler": + blk_coord = cute.arch.block_idx() + return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) + assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + return ( + cute.round_up(params.num_block, params.cluster_shape_mn[0]), + params.num_head * params.num_splits, + params.num_batch, + ) + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + block_idx, head_idx, batch_idx = self._blk_coord + if const_expr(self.params.is_split_kv): + head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod) + else: + split_idx = Int32(0) + return WorkTileInfo( + (block_idx, head_idx, batch_idx, split_idx), + self._is_first_block, + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._is_first_block = False + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._blk_coord]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._blk_coord], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class StaticPersistentTileScheduler: + @dataclass + class Params(ParamsBase): + num_block_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor + total_blocks: Int32 + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler.Params": + total_blocks = args.num_block * args.num_head * args.num_batch + return StaticPersistentTileScheduler.Params( + FastDivmodDivisor(args.num_block), FastDivmodDivisor(args.num_head), total_blocks + ) + + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler": + tile_idx = cute.arch.block_idx()[0] + return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + return (cutlass.min(sm_count, params.total_blocks), Int32(1), Int32(1)) + + # @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_divmod) + batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) + is_valid = self._tile_idx < 
self.params.total_blocks + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) + return WorkTileInfo( + (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._tile_idx += cute.arch.grid_dim()[0] + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.params, self._tile_idx], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileLPTScheduler: + @dataclass + class Params(ParamsBase): + total_blocks: Int32 + num_splits: Int32 + num_block: Int32 + l2_minor: Int32 + num_block_divmod: FastDivmodDivisor + num_head_divmod: FastDivmodDivisor + l2_minor_divmod: FastDivmodDivisor + l2_major_divmod: FastDivmodDivisor + l2_minor_residual_divmod: FastDivmodDivisor + num_hb_quotient: Int32 + is_split_kv: cutlass.Constexpr[bool] = False + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler.Params": + # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size) + size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size + size_one_head = size_one_kv_head + size_l2 = 50 * 1024 * 1024 # 40 MB for K & V + # Swizzle is the size of each "section". Round swizzle to a power of 2 + # Need to be careful about the case where only one head will fit + # swizzle is how many heads can fit in L2 + # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) + # Seems faster if swizzle if a power of 2 + log2_floor = lambda n: 31 - clz(n) + swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) + # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. 
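+            # Illustrative numbers only (not taken from the code): seqlen_k=8192, headdim=headdim_v=128,
+            # fp16 gives size_one_head = 8192 * 256 * 2 = 4 MiB, so about 12 heads fit in the
+            # 50 * 1024 * 1024 byte budget and swizzle rounds down to 8. With num_head * num_batch = 20,
+            # that yields 2 full sections of 8 heads (num_hb_quotient) plus a residual section of
+            # 4 heads (num_hb_remainder).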
+ num_hb_quotient = (args.num_head * args.num_batch) // swizzle + num_hb_remainder = (args.num_head * args.num_batch) % swizzle + return SingleTileLPTScheduler.Params( + total_blocks=args.num_block * args.num_head * args.num_batch, + num_block=args.num_block, + l2_minor=Int32(swizzle), + num_block_divmod=FastDivmodDivisor(args.num_block), + num_head_divmod=FastDivmodDivisor(args.num_head), + l2_minor_divmod=FastDivmodDivisor(swizzle), + l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block), + l2_minor_residual_divmod=FastDivmodDivisor( + max(num_hb_remainder, 1) + ), # don't divide by 0 + num_hb_quotient=Int32(num_hb_quotient), + num_splits=args.num_splits, + is_split_kv=args.is_split_kv, + ) + + def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._split_idx = split_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + @cute.jit + def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler": + tile_idx, split_idx, _ = cute.arch.block_idx() + return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + return (params.total_blocks, params.num_splits, Int32(1)) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + params = self.params + # Implement LPT scheduling coordinate calculation + bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. 
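+        # With the same illustrative numbers (swizzle = 8): consecutive tile indices first sweep the
+        # 8 (head, batch) pairs of a section at a fixed block, then advance the block, so only those
+        # 8 heads' K/V are touched for swizzle * num_block consecutive CTAs and can stay resident in
+        # L2. The final `num_block - 1 - block` flip issues high block indices first, i.e.
+        # longest-processing-time-first (presumably because, with causal masking, later query blocks
+        # attend to more keys).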
+ block, bidhb_residual = 0, 0 + if bidhb < params.num_hb_quotient: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod) + else: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) + bidhb_actual = bidhb * params.l2_minor + bidhb_residual + batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) + # Longest-processing-time-first + block = params.num_block - 1 - block + is_valid = self._tile_idx < params.total_blocks + return WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._tile_idx = self.params.total_blocks + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx, self._split_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileLPTBwdScheduler: + @dataclass + class Params(ParamsBase): + total_blocks: Int32 + num_block: Int32 + l2_minor: Int32 + num_head_divmod: FastDivmodDivisor + l2_minor_divmod: FastDivmodDivisor + l2_major_divmod: FastDivmodDivisor + l2_minor_residual_divmod: FastDivmodDivisor + num_hb_quotient: Int32 + cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + spt: cutlass.Constexpr[bool] = True + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileLPTBwdScheduler.Params": + size_l2 = 50 * 1024 * 1024 + size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size + # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4 + size_one_dqaccum_head = 0 + size_one_head = size_one_qdo_head + size_one_dqaccum_head + log2_floor = lambda n: 31 - clz(n) + swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) + # swizzle = 8 + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. 
+ num_hb_quotient = (args.num_head * args.num_batch) // swizzle + num_hb_remainder = (args.num_head * args.num_batch) % swizzle + num_block = cute.ceil_div(args.num_block, args.cluster_shape_mn[0]) + return SingleTileLPTBwdScheduler.Params( + total_blocks=(num_block * args.cluster_shape_mn[0]) + * args.num_head + * args.num_batch, + num_block=num_block, + l2_minor=Int32(swizzle), + num_head_divmod=FastDivmodDivisor(args.num_head), + l2_minor_divmod=FastDivmodDivisor(swizzle), + l2_major_divmod=FastDivmodDivisor(swizzle * num_block), + l2_minor_residual_divmod=FastDivmodDivisor( + max(num_hb_remainder, 1) + ), # don't divide by 0 + num_hb_quotient=Int32(num_hb_quotient), + cluster_shape_mn=args.cluster_shape_mn, + spt=args.lpt, + ) + + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + @cute.jit + def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTBwdScheduler": + tile_idx = cute.arch.block_idx()[0] + return SingleTileLPTBwdScheduler(params, tile_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + return (params.total_blocks, Int32(1), Int32(1)) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + cluster_idx = self._tile_idx // self.params.cluster_shape_mn[0] + params = self.params + # Implement LPT scheduling coordinate calculation + bidhb, l2_mod = divmod(cluster_idx, params.l2_major_divmod) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. 
+ block, bidhb_residual = 0, 0 + if bidhb < params.num_hb_quotient: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod) + else: + block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod) + bidhb_actual = bidhb * params.l2_minor + bidhb_residual + batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) + is_valid = self._tile_idx < params.total_blocks + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0] + if cutlass.const_expr(params.spt): + block = params.num_block - 1 - block + return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._tile_idx = self.params.total_blocks + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._tile_idx], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileVarlenScheduler: + @dataclass + class Params(ParamsBase): + num_head: Int32 + num_batch: Int32 + total_q: Int32 + num_splits: Int32 + max_kvblock_in_l2: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + lpt: cutlass.Constexpr[bool] = False + is_split_kv: cutlass.Constexpr[bool] = False + head_swizzle: cutlass.Constexpr[bool] = False + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileVarlenScheduler.Params": + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + max_kvblock_in_l2 = size_l2 // ( + (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + ) + assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( + "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" + ) + return SingleTileVarlenScheduler.Params( + num_head=args.num_head, + num_batch=args.num_batch, + total_q=args.total_q, + num_splits=args.num_splits, + max_kvblock_in_l2=max_kvblock_in_l2, + tile_shape_mn=args.tile_shape_mn, + mCuSeqlensQ=args.mCuSeqlensQ, + mSeqUsedQ=args.mSeqUsedQ, + qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + lpt=args.lpt, + is_split_kv=args.is_split_kv, + head_swizzle=args.head_swizzle, + ) + + def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._split_idx = split_idx + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler": + tile_idx, split_idx, _ = cute.arch.block_idx() + return 
SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + total_blocks_max = ( + params.total_q + params.num_batch * (params.tile_shape_mn[0] - 1) + ) // params.tile_shape_mn[0] + return (total_blocks_max * params.num_head, params.num_splits, Int32(1)) + + @cute.jit + def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + params = self.params + batch_idx = lane + bidb_start + if cutlass.const_expr(params.mSeqUsedQ is not None): + seqlen = Int32(0) + if batch_idx < params.num_batch: + seqlen = params.mSeqUsedQ[batch_idx] + else: + assert params.mCuSeqlensQ is not None + cur_cu_seqlen = Int32(0) + if batch_idx <= params.num_batch: + cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): + seqlen *= params.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(seqlen, params.tile_shape_mn[0]) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + params = self.params + lane_idx = cute.arch.lane_idx() + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + # Total number of blocks for the next 31 batches + m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1) + # Same for all lanes + group_end_tile = m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group) + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0) + next_tile_idx = self._tile_idx + while group_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= params.num_batch: + batch_idx = Int32(params.num_batch) + group_end_tile = next_tile_idx + 1 + else: + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync( + num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 + ) + group_end_tile += m_blocks_in_group * params.num_head + is_valid = False + if batch_idx >= params.num_batch: + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch) + else: + group_start_tile = group_end_tile - m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx) + # The next problem to process is the first one that does not have ending tile position + # that is greater than or equal to tile index. 
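+            # Each lane holds the running tile total for one batch of the current 31-batch window; the
+            # ballot below marks every batch whose tiles all precede next_tile_idx, so popc gives the
+            # number of whole batches to skip. Illustrative example (num_head = 1): with cumulative
+            # counts [3, 7, 12, ...] and next_tile_idx = 9, two bits are set, so we land in the third
+            # batch of the window at m-block 9 - 7 = 2.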
+ batch_idx_in_group = cute.arch.popc( + cute.arch.vote_ballot_sync( + group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_group + num_m_blocks_prev_lane = ( + 0 + if batch_idx_in_group == 0 + else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) + ) + num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) + mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head + if cutlass.const_expr(params.lpt or params.head_swizzle): + # This is a version of the SingleTileLPTScheduler, complicated by the fact that + # the seqlen can vary per batch. + # TODO: is there any case where num_m_blocks is 0? + # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here + num_n_blocks = ( + num_m_blocks + * params.tile_shape_mn[0] + // params.qhead_per_kvhead_packgqa + // params.tile_shape_mn[1] + ) + # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) + # Seems faster to have this be a power of 2 + nheads_in_l2 = ( + 16 + if num_n_blocks * 16 <= params.max_kvblock_in_l2 + else ( + 8 + if num_n_blocks * 8 <= params.max_kvblock_in_l2 + else ( + 4 + if num_n_blocks * 4 <= params.max_kvblock_in_l2 + else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1) + ) + ) + ) + nheads_in_l2 = min(nheads_in_l2, params.num_head) + mh_in_l2 = nheads_in_l2 * num_m_blocks + section_idx = mh_block // mh_in_l2 + l2_mod = mh_block - section_idx * mh_in_l2 + # Deal with tail section + nheads_in_this_section = ( + nheads_in_l2 + if nheads_in_l2 * (section_idx + 1) <= params.num_head + else params.num_head - section_idx * nheads_in_l2 + ) + block = l2_mod // nheads_in_this_section + head_idx_residual = l2_mod - block * nheads_in_this_section + head_idx = section_idx * nheads_in_l2 + head_idx_residual + if cutlass.const_expr(params.lpt): + block = num_m_blocks - 1 - block + else: + head_idx = mh_block // num_m_blocks + block = mh_block - head_idx * num_m_blocks + is_valid = self._is_first_block and batch_idx < params.num_batch + # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) + split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) + return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._is_first_block = False + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx, self._split_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.params, self._tile_idx, self._split_idx], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py new file mode 100644 index 00000000000..f31d85c5d44 --- 
/dev/null +++ b/flash_attn/cute/utils.py @@ -0,0 +1,859 @@ +# Copyright (c) 2025, Tri Dao. + +import math +import hashlib +import inspect +import re +from typing import Type, Callable, Optional, Tuple, overload +from functools import partial + +import cutlass +import cutlass.cute as cute + +from cutlass import Float32, const_expr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import nvvm, llvm +from cutlass.cute.runtime import from_dlpack + + +# cute.arch.{fma,mul,add}_packed_f32x2 uses RZ rounding mode by default +fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) +mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) +add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) +sub_packed_f32x2 = partial( + cute.arch.calc_packed_f32x2_op, + src_c=None, + calc_func=nvvm.sub_packed_f32x2, + rnd=nvvm.RoundingModeKind.RN, +) + + +def hash_callable(func: Callable, set_cute_hash=True) -> str: + """Hash a callable based on the source code or bytecode and closure values. + + Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__`` + attribute, that value is returned immediately. Code-generation backends such + as Inductor can set this attribute to avoid expensive runtime hashing. + + set_cute_hash: whether or not to set func.__cute_hash__ if not present + """ + if hasattr(func, "__cute_hash__"): + return func.__cute_hash__ + + # Unwrap decorated functions (e.g., cute.jit wrappers). + if hasattr(func, "__wrapped__"): + base_func = func.__wrapped__ + if hasattr(base_func, "__cute_hash__"): + return base_func.__cute_hash__ + func = base_func + + try: + data = inspect.getsource(func).encode() + except (OSError, TypeError): + if hasattr(func, "__code__") and func.__code__ is not None: + data = func.__code__.co_code + else: + data = repr(func).encode() + + hasher = hashlib.sha256(data) + + if hasattr(func, "__closure__") and func.__closure__ is not None: + for idx, cell in enumerate(func.__closure__): + cell_value = cell.cell_contents + hasher.update(repr(cell_value).encode()) + + hash = hasher.hexdigest() + + if set_cute_hash: + func.__cute_hash__ = hash + + return hash + + +def create_softcap_scoremod(softcap_val): + inv_softcap = 1.0 / softcap_val + + @cute.jit + def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors): + scores = acc_S_SSA * inv_softcap + return scores * cute.math.tanh(scores, fastmath=True) + + return scoremod_premask_fn + + +def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: + return ( + from_dlpack(x, assumed_align=alignment) + .mark_layout_dynamic(leading_dim=leading_dim) + .mark_compact_shape_dynamic( + mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility + ) + ) + + +def convert_from_dlpack_leading_static( + x, leading_dim, alignment=16, static_modes=None, stride_order=None +) -> cute.Tensor: + if stride_order is None: + stride_order = x.dim_order() + x_ = from_dlpack(x, assumed_align=alignment) + for i in range(x.ndim): + if i != leading_dim and (static_modes is None or i not in static_modes): + x_ = x_.mark_compact_shape_dynamic(mode=i, stride_order=stride_order) + return x_ + + +def make_tiled_copy_A( + copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.TiledCopy: + if const_expr(swapAB): + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + 
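+# When swapAB is set, the operand that is logically "A" is fed to the MMA's B slot (and vice
+# versa), so this pair of helpers simply routes to the opposite cute.make_tiled_copy_* call;
+# mma_make_fragment_A/B below follow the same convention for register fragments.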
+ +def make_tiled_copy_B( + copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.TiledCopy: + if const_expr(swapAB): + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + + +def mma_make_fragment_A( + smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.Tensor: + if const_expr(swapAB): + return mma_make_fragment_B(smem, thr_mma) + else: + return thr_mma.make_fragment_A(thr_mma.partition_A(smem)) + + +def mma_make_fragment_B( + smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.Tensor: + if const_expr(swapAB): + return mma_make_fragment_A(smem, thr_mma) + else: + return thr_mma.make_fragment_B(thr_mma.partition_B(smem)) + + +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=2 * element_type.width, + ) + else: + return cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +@cute.jit +def warp_reduce( + val: cute.TensorSSA | cute.Numeric, + op: Callable, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.TensorSSA | cute.Numeric: + if const_expr(isinstance(val, cute.TensorSSA)): + res = cute.make_fragment(val.shape, val.dtype) + res.store(val) + for i in cutlass.range_constexpr(cute.size(val.shape)): + res[i] = warp_reduce(res[i], op, width) + return res.load() + else: + for i in cutlass.range_constexpr(int(math.log2(width))): + val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i)) + return val + + +def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout: + """ + For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...). + For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...). + """ + acc_layout_col_major = cute.make_layout(acc_layout.shape) + shape = ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), # MMA_N + *acc_layout_col_major.shape[3:], + ) + stride = ( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), # MMA_N + *acc_layout_col_major.stride[3:], + ) + if const_expr(transpose): + shape = (shape[1], shape[0], *shape[2:]) + stride = (stride[1], stride[0], *stride[2:]) + acc_layout_mn = cute.make_layout(shape, stride=stride) + return cute.composition(acc_layout, acc_layout_mn) + + +def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose)) + + +@cute.jit +def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: + # For back to back gemm, convert layout of acc0 to gemm 1 accept layout. 
+ # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) + # TODO: Sm90 FP8 + if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 + l = cute.logical_divide( + acc_layout, ((None, None, 2), None, None) + ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N) + rA_mma_view = cute.make_layout( + ( + (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]), + l.shape[1], + (l.shape[0][2][1], l.shape[2]), + ), + stride=( + (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]), + l.stride[1], + (l.stride[0][2][1], l.stride[2]), + ), + ) + else: # Sm80 + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + l = cute.logical_divide(acc_layout, (None, None, 2)) + rA_mma_view = cute.make_layout( + ( + (l.shape[0], l.shape[2][0]), + l.shape[1], + l.shape[2][1], + ), + stride=( + (l.stride[0], l.stride[2][0]), + l.stride[1], + l.stride[2][1], + ), + ) + return rA_mma_view + + +def make_acc_tensor_frgA_view(acc: cute.Tensor) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout)) + + +def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor: + return cute.make_tensor(a.iterator, cute.select(a.layout, mode)) + + +def transpose_view(a: cute.Tensor) -> cute.Tensor: + """Transpose the first two dimensions of a tensor on smem.""" + shape = (a.shape[1], a.shape[0], *a.shape[2:]) + order = (1, 0, *range(2, cute.rank(a))) + return cute.composition(a, cute.make_ordered_layout(shape, order=order)) + # stride = (a.layout.stride[1], a.layout.stride[0], *a.layout.stride[2:]) + # return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride)) + + +def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle: + """Extract swizzle parameters from a pointer's swizzle_type. + + The swizzle_type string has the form '!cute.swizzle<"S">' where + b, m, s are the swizzle parameters (bits, base, shift). + + Returns: + A cute.Swizzle object constructed from the extracted parameters + + Raises: + ValueError: If the swizzle_type string cannot be parsed + """ + # Ideally there should be a better API to get swizzle parameters, but we'll just parse + # the string here. + swizzle_str = str(ptr.type.swizzle_type) + # Extract the inner part "S" + match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str) + if match: + b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3)) + return cute.make_swizzle(b, m, s) + else: + raise ValueError(f"Could not parse swizzle_type: {swizzle_str}") + + +@cute.jit +def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32: + """exp2f calculation for both vector and scalar. 
+ :param x: input value + :type x: cute.TensorSSA or Float32 + :return: exp2 value + :rtype: cute.TensorSSA or Float32 + """ + if const_expr(isinstance(x, cute.TensorSSA)): + res = cute.make_fragment(x.shape, Float32) + res.store(x) + for i in cutlass.range_constexpr(cute.size(x.shape)): + res[i] = cute.arch.exp2(res[i]) + return res.load() + else: + return cute.arch.exp2(x) + + +@dsl_user_op +def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "lg2.approx.ftz.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def logf(a: float | Float32, *, loc=None, ip=None) -> Float32: + return log2f(a, loc=loc, ip=ip) * math.log(2.0) + + +@dsl_user_op +def fmax( + a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None +) -> Float32: + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + + +@cute.jit +def fmax_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + # if const_expr(init_val is None): + # init_val = -cutlass.Float32.if + # return x.reduce(cute.ReductionOp.MAX, init_val, 0) + res = cute.make_fragment(x.shape, Float32) + res.store(x) + # local_max = [res[0], res[1]] + # for i in cutlass.range_constexpr(2, cute.size(x.shape), 2): + # local_max[0] = fmax(local_max[0], res[i + 0]) + # local_max[1] = fmax(local_max[1], res[i + 1]) + # local_max[0] = fmax(local_max[0], local_max[1]) + # return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val) + local_max = [res[0], res[1], res[2], res[3]] + for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + local_max[0] = fmax(local_max[0], res[i + 0]) + local_max[1] = fmax(local_max[1], res[i + 1]) + local_max[2] = fmax(local_max[2], res[i + 2]) + local_max[3] = fmax(local_max[3], res[i + 3]) + local_max[0] = fmax(local_max[0], local_max[1]) + local_max[2] = fmax(local_max[2], local_max[3]) + local_max[0] = fmax(local_max[0], local_max[2]) + return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val) + else: + # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max + # We instead force the 3-input max. 
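+        # Shape of the reduction tree below (length a multiple of 8): keep 4 running maxima, fold two
+        # new elements into each with a 3-operand fmax per step, then merge the partials with one
+        # 2-operand and one final 3-operand fmax. The 3-operand form presumably maps to a single
+        # instruction on arch >= 100, which is why the generic x.reduce path is kept for older archs.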
+ res = cute.make_fragment(x.shape, Float32) + res.store(x) + local_max_0 = ( + fmax(init_val, res[0], res[1]) + if const_expr(init_val is not None) + else fmax(res[0], res[1]) + ) + local_max = [ + local_max_0, + fmax(res[2], res[3]), + fmax(res[4], res[5]), + fmax(res[6], res[7]), + ] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_max[0] = fmax(local_max[0], res[i], res[i + 1]) + local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3]) + local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5]) + local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7]) + local_max[0] = fmax(local_max[0], local_max[1]) + return fmax(local_max[0], local_max[2], local_max[3]) + + +@cute.jit +def fadd_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + if const_expr(init_val is None): + init_val = Float32.zero + return x.reduce(cute.ReductionOp.ADD, init_val, 0) + # res = cute.make_fragment(x.shape, Float32) + # res.store(x) + # local_sum = [res[0], res[1], res[2], res[3]] + # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + # local_sum[0] += res[i + 0] + # local_sum[1] += res[i + 1] + # local_sum[2] += res[i + 2] + # local_sum[3] += res[i + 3] + # local_sum[0] += local_sum[1] + # local_sum[2] += local_sum[3] + # local_sum[0] += local_sum[2] + # return local_sum[0] if const_expr(init_val is None) else local_sum[0] + init_val + else: + res = cute.make_fragment(x.shape, Float32) + res.store(x) + local_sum_0 = ( + add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + # add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1])) + if const_expr(init_val is not None) + else (res[0], res[1]) + ) + local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_sum[0] = add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1])) + local_sum[1] = add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3])) + local_sum[2] = add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5])) + local_sum[3] = add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7])) + local_sum[0] = add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + + +@dsl_user_op +def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None: + # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + # # cache_hint = cutlass.Int64(0x12F0000000000000) + # llvm.inline_asm( + # None, + # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip)], + # # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], + # "red.global.add.f32 [$0], $1;", + # # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", + # # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", + # "l,f", + # # "l,f,l", + # has_side_effects=True, + # is_align_stack=False, + # asm_dialect=llvm.AsmDialect.AD_ATT, + # ) + nvvm.atomicrmw( + res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value() + ) + + +@dsl_user_op +def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + +@dsl_user_op +def elem_pointer_i64(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + flat_coord_i64 = 
tuple(cutlass.Int64(c) for c in cute.flatten(coord)) + flat_stride = cute.flatten_to_tuple(x.stride) + assert len(flat_coord_i64) == len(flat_stride), ( + "Coordinate and stride must have the same length" + ) + offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) + # HACK: we assume that applying the offset does not change the pointer alignment + byte_offset = offset * x.element_type.width // 8 + return cute.make_ptr( + x.element_type, + x.iterator.toint() + byte_offset, + x.memspace, + assumed_align=x.iterator.alignment, + ) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" + tApA = cute.make_fragment( + cute.make_layout( + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + +def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: + warp_group_idx = cute.arch.thread_idx()[0] // 128 + if const_expr(sync): + warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx) + return warp_group_idx + + +# @dsl_user_op +# def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean: +# mask = cutlass.Int32(-1) +# return cutlass.Boolean( +# llvm.inline_asm( +# T.i32(), +# [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)], +# ".pred p1, p2;\n" +# "setp.lt.f32 p1, $1, $2;\n" +# "vote.sync.any.pred p2, p1, $3;\n" +# "selp.u32 $0, 1, 0, p2;", +# # "selp.u32 $0, 1, 0, p1;", +# "=r,f,f,r", +# has_side_effects=False, +# is_align_stack=False, +# asm_dialect=llvm.AsmDialect.AD_ATT, +# ) +# ) + + +@cute.jit +def shuffle_sync( + value: cute.Numeric, + offset: cute.typing.Int, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.Numeric: + assert value.width % 32 == 0, "value type must be a multiple of 32 bits" + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + # important: need stride 1 and not 0 for recast_tensor to work + val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value)) + val[0] = value + val_i32 = cute.recast_tensor(val, cutlass.Int32) + for i in cutlass.range_constexpr(cute.size(val_i32)): + val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp) + return val[0] + + +@dsl_user_op +def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [ + cutlass.Uint32(val).ir_value(loc=loc, ip=ip), + cutlass.Uint32(shift).ir_value(loc=loc, ip=ip), + ], + "shr.s32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32: + if const_expr(lane is None): + lane = cute.arch.lane_idx() + # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 32, val) + for i in 
cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): + offset = 1 << i + # Very important that we set mask_and_clamp to 0 + partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0) + if lane >= offset: + val += partial_sum + # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val) + return val + + +@dsl_user_op +def cvt_f16x2_f32( + a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None +) -> cutlass.Int32: + assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16" + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)], + f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@overload +def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ... + + +@overload +def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ... + + +@cute.jit +def cvt_f16(src: cute.Tensor, dst_or_dtype): + """Convert Float32 tensor to Float16/BFloat16. + + Args: + src: Source tensor with Float32 element type + dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16) + + Returns: + None if dst is a tensor, or a new tensor if dtype is provided + """ + if const_expr(isinstance(dst_or_dtype, type)): + # dtype variant: create new tensor and call the tensor variant + dtype = dst_or_dtype + dst = cute.make_fragment(src.shape, dtype) + cvt_f16(src, dst) + return dst + else: + # tensor variant: write to dst + dst = dst_or_dtype + assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" + assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], ( + "dst must be BFloat16 or Float16" + ) + assert src.element_type is Float32, "src must be Float32" + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) + for i in cutlass.range_constexpr(cute.size(dst_i32)): + dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) + + +@dsl_user_op +@cute.jit +def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32: + deg = len(poly) - 1 + out = poly[deg] + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = out * x + poly[i] + return out + + +@dsl_user_op +@cute.jit +def evaluate_polynomial_2( + x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None +) -> Tuple[Float32, Float32]: + deg = len(poly) - 1 + out = (poly[deg], poly[deg]) + for i in cutlass.range_constexpr(deg - 1, -1, -1): + out = fma_packed_f32x2(out, (x, y), (poly[i], poly[i])) + return out + + +@dsl_user_op +def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32: + # There's probably a way to call llvm or nvvm to do this instead of ptx + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)], + "add.rm.ftz.f32 $0, $1, $2;", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, 
loc=None, ip=None) -> Float32: + return cutlass.Float32( + llvm.inline_asm( + T.f32(), + [ + Float32(x_rounded).ir_value(loc=loc, ip=ip), + Float32(frac_ex2).ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t" + "mov.b32 x_rounded_i, $1;\n\t" + "mov.b32 frac_ex_i, $2;\n\t" + "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t" + # add.u32 generates IMAD instruction and add.s32 generates LEA instruction + # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik + "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t" + "mov.b32 $0, out_i;\n\t" + "}\n", + "=f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32: + # We assume x <= 127.0 + poly_ex2_deg3 = ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ) + fp32_round_int = float(2**23 + 2**22) + x_clamped = cute.arch.fmax(x, -127.0) + # We want to round down here, so that the fractional part is in [0, 1) + x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip) + # The integer floor of x is now in the last 8 bits of x_rounded + # We assume the next 2 ops round to nearest even. The rounding mode is important. + x_rounded_back = x_rounded - fp32_round_int + x_frac = x_clamped - x_rounded_back + x_frac_ex2 = evaluate_polynomial(x_frac, poly_ex2_deg3, loc=loc, ip=ip) + return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip) + + +# TODO: check that the ex2_emulation_2 produces the same SASS as the ptx version +@dsl_user_op +def ex2_emulation_2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + # We assume x <= 127.0 and y <= 127.0 + poly_ex2_deg3 = ( + 1.0, + 0.695146143436431884765625, + 0.227564394474029541015625, + 0.077119089663028717041015625, + ) + fp32_round_int = float(2**23 + 2**22) + xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0)) + # We want to round down here, so that the fractional part is in [0, 1) + xy_rounded = cute.arch.add_packed_f32x2( + xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM + ) + # The integer floor of x & y are now in the last 8 bits of xy_rounded + # We want the next 2 ops to round to nearest even. The rounding mode is important. 
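+    # In more detail: adding 1.5 * 2**23 (fp32_round_int) with round-toward-minus-infinity forces
+    # floor(x) and floor(y) into the low mantissa bits of xy_rounded (the ulp is 1 in that range).
+    # Subtracting the constant back is exact, so xy_clamped - xy_rounded_back leaves fractional
+    # parts in [0, 1) for the polynomial, and combine_int_frac_ex2 later shifts the integer bits
+    # into the exponent field to rebuild 2**x and 2**y.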
+ xy_rounded_back = sub_packed_f32x2(xy_rounded, (fp32_round_int, fp32_round_int)) + xy_frac = sub_packed_f32x2(xy_clamped, xy_rounded_back) + xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, poly_ex2_deg3, loc=loc, ip=ip) + x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip) + y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip) + return x_out, y_out + + +@dsl_user_op +def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + out_f32x2 = llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32()]), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()], + "{\n\t" + ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t" + ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t" + ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t" + "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t" + "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t" + "mov.b64 l1, {f1, f2};\n\t" + "mov.f32 f3, 0f4B400000;\n\t" + "mov.b64 l2, {f3, f3};\n\t" + "add.rm.ftz.f32x2 l7, l1, l2;\n\t" + "sub.rn.ftz.f32x2 l8, l7, l2;\n\t" + "sub.rn.ftz.f32x2 l9, l1, l8;\n\t" + "mov.f32 f7, 0f3D9DF09D;\n\t" + "mov.b64 l6, {f7, f7};\n\t" + "mov.f32 f6, 0f3E6906A4;\n\t" + "mov.b64 l5, {f6, f6};\n\t" + "mov.f32 f5, 0f3F31F519;\n\t" + "mov.b64 l4, {f5, f5};\n\t" + "mov.f32 f4, 0f3F800000;\n\t" + "mov.b64 l3, {f4, f4};\n\t" + "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t" + "mov.b64 {r1, r2}, l7;\n\t" + "mov.b64 {r3, r4}, l10;\n\t" + "shl.b32 r5, r1, 23;\n\t" + "add.s32 r7, r5, r3;\n\t" + "shl.b32 r6, r2, 23;\n\t" + "add.s32 r8, r6, r4;\n\t" + "mov.b32 $0, r7;\n\t" + "mov.b32 $1, r8;\n\t" + "}\n", + "=r,=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip)) + out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip)) + return out0, out1 + + +@dsl_user_op +def domain_offset_aligned( + coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None +) -> cute.Tensor: + assert isinstance(tensor.iterator, cute.Pointer) + # We assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + elem_pointer(tensor, coord).toint(), + tensor.memspace, + assumed_align=tensor.iterator.alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@dsl_user_op +def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) + flat_stride = cute.flatten_to_tuple(tensor.stride) + assert len(flat_coord_i64) == len(flat_stride), ( + "Coordinate and stride must have the same length" + ) + offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) + assert isinstance(tensor.iterator, cute.Pointer) + # HACK: we assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@dsl_user_op +def coord_offset_i64( + tensor: cute.Tensor, idx: cute.typing.Int, dim: int, *, loc=None, ip=None +) -> cute.Tensor: + offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim]) + assert isinstance(tensor.iterator, cute.Pointer) + # HACK: we assume that applying the offset does 
not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + new_layout = cute.slice_( + tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1)) + ) + return cute.make_tensor(new_ptr, new_layout) + + +@cute.jit +def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: + """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype""" + vec = cute.make_fragment(1, dtype) + vec[0] = a + return vec.load() + + +def ssa_to_scalar(val): + """Could inline but nice for reflecting the above api""" + return val[0] diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 30134990d68..865f1db5432 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -38,11 +38,6 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal): return 64 if (not is_dropout and is_causal) else 32 else: return 64 if not is_dropout else 32 - elif head_dim <= 160: - if is_sm8x: - return 64 - else: - return 32 elif head_dim <= 192: return 64 elif head_dim <= 224: @@ -132,7 +127,10 @@ def _flash_attn_forward_fake( softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout) p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) if return_softmax: - p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) + if torch.cuda.is_available() and torch.version.hip: + p = torch.empty((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout) + else: + p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) return out, softmax_lse, p, rng_state @@ -225,10 +223,11 @@ def _flash_attn_varlen_forward_fake( out = torch.empty_like(q) softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) - seqlen_q_rounded = round_multiple(max_seqlen_q, 128) - seqlen_k_rounded = round_multiple(max_seqlen_k, 128) if return_softmax: - p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) + if torch.cuda.is_available() and torch.version.hip: + p = torch.empty((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout) + else: + p = torch.empty((batch_size, num_heads, round_multiple(max_seqlen_q, 128), round_multiple(max_seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) return out, softmax_lse, p, rng_state @@ -320,7 +319,10 @@ def _flash_attn_backward_fake( if dv is None: dv = torch.empty_like(v) batch_size, seqlen_q, num_heads, _ = q.shape - softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32) + if torch.cuda.is_available() and torch.version.hip: + softmax_d = torch.empty((batch_size, num_heads, seqlen_q), device=q.device, dtype=torch.float32) + else: + softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32) 
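+        # note: the HIP/Triton path above keeps the exact seqlen_q, while the CUDA path pads the
+        # seqlen dimension of softmax_d up to a multiple of 128, matching what the kernels return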
return softmax_d @@ -431,7 +433,10 @@ def _flash_attn_varlen_backward_fake( dk = torch.empty_like(k) if dv is None: dv = torch.empty_like(v) - softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) + if torch.cuda.is_available() and torch.version.hip: + softmax_d = torch.empty((num_heads, total_q), device=q.device, dtype=torch.float32) + else: + softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) return softmax_d @@ -1581,7 +1586,7 @@ def flash_attn_with_kvcache( softmax_scale = q.shape[-1] ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) cache_seqlens = maybe_contiguous(cache_seqlens) cache_batch_idx = maybe_contiguous(cache_batch_idx) diff --git a/flash_attn/flash_attn_triton_amd/Dockerfile b/flash_attn/flash_attn_triton_amd/Dockerfile new file mode 100644 index 00000000000..29a2c0c43ec --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/Dockerfile @@ -0,0 +1,17 @@ +FROM rocm/pytorch:latest + +WORKDIR /workspace + +# install triton +RUN pip install triton==3.2.0 + +# install flash attention +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + +RUN git clone https://github.com/ROCm/flash-attention.git &&\ + cd flash-attention &&\ + git checkout main_perf &&\ + python setup.py install + +# set working dir +WORKDIR /workspace/flash-attention \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/README.md b/flash_attn/flash_attn_triton_amd/README.md index 798d78a12d9..2d8fd8e70f3 100644 --- a/flash_attn/flash_attn_triton_amd/README.md +++ b/flash_attn/flash_attn_triton_amd/README.md @@ -11,39 +11,103 @@ These features are supported in Fwd and Bwd 2) Variable sequence lengths 3) Arbitrary Q and KV sequence lengths 4) Arbitrary head sizes +5) Multi and grouped query attention +6) Dropout +7) Rotary embeddings +8) ALiBi -These features are supported in Fwd for now. We will add them to backward soon. -1) Multi and grouped query attention -2) ALiBi and matrix bias - -These features are in development +We are working on the following things 1) Paged Attention 2) Sliding Window -3) Rotary embeddings -4) Dropout -5) Performance Improvements +3) FP8 +4) Performance Improvements -#### Getting Started +##### Getting Started To get started with the triton backend for AMD, follow the steps below. -First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/3ca2f498e98ed7249b82722587c511a5610e00c4). +First install the recommended Triton version ``` -git clone https://github.com/triton-lang/triton -cd triton -git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4 -pip install --verbose -e python +pip install triton==3.2.0 ``` -Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. +Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. ``` -export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" cd flash-attention -python setup.py install -pytest tests/test_flash_attn.py +git checkout main_perf +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install +``` + +To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing. 
+``` +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py +``` + +You can use autotune for better performance by setting the flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` +``` +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE +``` + +###### Docker +You can also use the Dockerfile below, which performs the above steps on top of the latest rocm/pytorch image. +``` +FROM rocm/pytorch:latest + +WORKDIR /workspace + +# install triton +RUN pip install triton==3.2.0 + +# install flash attention +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + +RUN git clone https://github.com/ROCm/flash-attention.git &&\ + cd flash-attention &&\ + git checkout main_perf &&\ + python setup.py install + +# set working dir +WORKDIR /workspace/flash-attention ``` -#### Credits +To build the Docker image +``` +docker build -t fa_triton . +``` + +To run the Docker image +``` +docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton +``` + +###### FP8 +In our fork, we have created the following API functions that use FP8 to compute their values. These functions are `flash_attn_fp8_func`, `flash_attn_varlen_fp8_func`, `flash_attn_qkvpacked_fp8_func` and `flash_attn_varlen_qkvpacked_fp8_func`. To use these functions, just call them like the other API functions; the casting is handled internally. For example + +``` +from flash_attn import flash_attn_qkvpacked_fp8_func + +# forward pass +out, lse, S_dmask = flash_attn_qkvpacked_fp8_func( + qkv, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + +# backward pass +do = torch.randn_like(out) +dqkv = torch.autograd.grad(out, (qkv), do) +``` + +You can use the other API functions in a similar way. 
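+
+As a further illustration, here is a minimal sketch for the varlen variant. It assumes `flash_attn_varlen_fp8_func` takes the same arguments as `flash_attn_varlen_func`; the tensor names below are placeholders.
+
+```
+from flash_attn import flash_attn_varlen_fp8_func
+
+# q_unpad, k_unpad, v_unpad: packed (total_tokens, nheads, headdim) tensors
+# cu_seqlens_q, cu_seqlens_k: int32 cumulative sequence lengths of shape (batch + 1,)
+out_unpad, lse, S_dmask = flash_attn_varlen_fp8_func(
+    q_unpad,
+    k_unpad,
+    v_unpad,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p,
+    causal=causal,
+    window_size=window_size,
+    softcap=softcap,
+    alibi_slopes=alibi_slopes,
+    deterministic=deterministic,
+    return_attn_probs=True,
+)
+```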
+ + + +##### Credits AMD Triton kernels team OpenAI kernel team diff --git a/flash_attn/flash_attn_triton_amd/bench.py b/flash_attn/flash_attn_triton_amd/bench.py old mode 100644 new mode 100755 index 91939f831f0..05e64c349be --- a/flash_attn/flash_attn_triton_amd/bench.py +++ b/flash_attn/flash_attn_triton_amd/bench.py @@ -1,290 +1,1223 @@ -import argparse +import os +import sys import torch import triton -from flash_attn.flash_attn_triton_amd.utils import ( - MetaData, - input_helper, - varlen_input_helper, -) -from flash_attn.flash_attn_triton_amd.interface_torch import attention_prefill, attention_decode - -ARGS_TO_TORCH_DTYPE = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, +import time +import argparse +import itertools +import pandas as pd +from logging import warning +from typing import Dict, List, Literal, Optional, Tuple +from dataclasses import dataclass +from functools import lru_cache +from utils import get_arch, input_helper + +DEBUG = False + +ENV_FLAGS = ["FLASH_ATTENTION_TRITON_AMD_ENABLE", "FLASH_ATTENTION_TRITON_AMD_AUTOTUNE", "FLASH_ATTENTION_TRITON_AMD_DEBUG"] + +FUNCTIONS = [ + "flash_attn_func", + "flash_attn_fp8_func", + "flash_attn_kvpacked_func", + "flash_attn_qkvpacked_func", + "flash_attn_qkvpacked_fp8_func", + "flash_attn_varlen_func", + "flash_attn_varlen_fp8_func", + "flash_attn_varlen_kvpacked_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_qkvpacked_fp8_func", + "flash_attn_with_kvcache", +] + +SUPPORTED_DTYPES = { + "flash_attn_func": [torch.float16], + "flash_attn_fp8_func": [torch.float8_e4m3fnuz], + "flash_attn_kvpacked_func": [torch.float16], + "flash_attn_qkvpacked_func": [torch.float16], + "flash_attn_qkvpacked_fp8_func": [torch.float16], + "flash_attn_varlen_func": [torch.float16], + "flash_attn_varlen_fp8_func": [torch.float8_e4m3fnuz], + "flash_attn_varlen_kvpacked_func": [torch.float16], + "flash_attn_varlen_qkvpacked_func": [torch.float16], + "flash_attn_varlen_qkvpacked_fp8_func": [torch.float16], + "flash_attn_with_kvcache": [torch.float16], +} + +SUPPORTED_BACKENDS = { + "flash_attn_func": ["ck", "triton"], + "flash_attn_fp8_func": ["triton"], + "flash_attn_kvpacked_func": ["ck", "triton"], + "flash_attn_qkvpacked_func": ["ck", "triton"], + "flash_attn_qkvpacked_fp8_func": ["triton"], + "flash_attn_varlen_func": ["ck", "triton"], + "flash_attn_varlen_fp8_func": ["triton"], + "flash_attn_varlen_kvpacked_func": ["ck", "triton"], + "flash_attn_varlen_qkvpacked_func": ["ck", "triton"], + "flash_attn_varlen_qkvpacked_fp8_func": ["triton"], + "flash_attn_with_kvcache": ["ck", "triton"], } -FUNCTIONS = { - "prefill": attention_prefill, - "decode": attention_decode +VALID_MODES = ['fwd', 'bwd', 'full'] +SUPPORTED_MODES = { + "flash_attn_func": ["fwd", "bwd", "full"], + "flash_attn_fp8_func": ["fwd", "bwd", "full"], + "flash_attn_kvpacked_func": ["fwd", "bwd", "full"], + "flash_attn_qkvpacked_func": ["fwd", "bwd", "full"], + "flash_attn_qkvpacked_fp8_func": ["fwd", "bwd", "full"], + "flash_attn_varlen_func": ["fwd", "bwd", "full"], + "flash_attn_varlen_fp8_func": ["fwd", "bwd", "full"], + "flash_attn_varlen_kvpacked_func": ["fwd", "bwd", "full"], + "flash_attn_varlen_qkvpacked_func": ["fwd", "bwd", "full"], + "flash_attn_varlen_qkvpacked_fp8_func": ["fwd", "bwd", "full"], + "flash_attn_with_kvcache": ["fwd"], } -def get_benchmark_configs(args, varlen=False): +@dataclass +class EnvVariableConfig: + key: str + values: List[str] + backend: Optional[Literal["triton", "ck"]] = None + +ENV_VARIABLE_CONFIGS : 
List[EnvVariableConfig] = [ + EnvVariableConfig(key="BWD_MODE", values=["split", "fused", "jingning"], backend="triton"), +] + +class FunctionConfig: + def __init__(self, fn_name: str, mode: Literal["fwd", "bwd", "full"], dtype, backend: Literal["triton", "ck"], env_config: Dict): + self.fn_name = fn_name + self.mode: Literal["fwd", "bwd", "full"] = mode + self.dtype = dtype + self.backend: Literal["triton", "ck"] = backend + self.arch = get_arch() + self.env_configs = env_config + + def __str__(self): + # extract base dtype name if it's a torch dtype + dtype_str = str(self.dtype) + if "torch." in dtype_str: + dtype_str = dtype_str.split(".")[-1] + + if len(self.env_configs) > 0: + env_str = "" + for env_key, env_value in self.env_configs.items(): + env_str += f"{env_key}={env_value}" + return f"{self.fn_name}_{self.mode}_{dtype_str}_{self.backend}_{self.arch}_{env_str}" + else: + return f"{self.fn_name}_{self.mode}_{dtype_str}_{self.backend}_{self.arch}" + + def column_name(self): + return f"{self}_ms" + + +@lru_cache() +def available_backends(): + available = [] + + # try to load each backend + for backend in ["triton", "ck"]: + try: + # try loading the module with this backend + flash_attn = load_flash_attn_module(backend) + + # if we got here, the backend loaded successfully + available.append(backend) + except Exception as e: + # backend not available, just continue + print(f"Backend {backend} not available. Error: {e}") + + # if no backends available, default to triton + if not available: + raise ValueError("No Backends available") + + return available + +@lru_cache() +def get_fn_params(fn_name): + # get params for fn + packing = get_packing_type(fn_name) + is_varlen = True if "varlen" in fn_name else False + is_fp8 = True if "fp8" in fn_name else False + supported_dtypes = SUPPORTED_DTYPES.get(fn_name, [torch.float16]) # default to float16 if not found + supported_backends = [backend for backend in SUPPORTED_BACKENDS.get(fn_name, ["triton"]) if backend in available_backends()] # default to triton backend + supports_backward = False if fn_name in ["flash_attn_with_kvcache"] else True + supported_modes = SUPPORTED_MODES.get(fn_name, ["fwd"]) + device = "cuda" + + # get supported env configs for each backend + supported_env_configs = {} + for backend in supported_backends: + supported_env_configs[backend] = get_env_value_combinations(backend) + + # check backward pass support + if not supports_backward: + warning(f"{fn_name} does not have a backward pass so benching forward pass only.") + + return is_varlen, is_fp8, packing, supported_dtypes, supported_backends, supported_modes, supported_env_configs, device + +def generate_fn_inputs( + fn_name: str, + BATCH: int, + HQ: int, + HK: int, + N_CTX_Q: int, + N_CTX_K: int, + D_HEAD: int, + CAUSAL: bool, + DROPOUT_P: float, + dtype: torch.dtype, + device: Literal["cpu", "cuda"] + ): + if fn_name == "flash_attn_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", device=device) + elif fn_name == "flash_attn_kvpacked_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", packing="kv", device=device) + elif fn_name == "flash_attn_qkvpacked_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", packing="qkv", device=device) + elif fn_name == "flash_attn_varlen_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", 
device=device) + elif fn_name == "flash_attn_varlen_kvpacked_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", packing="kv", device=device) + elif fn_name == "flash_attn_varlen_qkvpacked_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", packing="qkv", device=device) + elif fn_name == "flash_attn_with_kvcache": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", device=device) + elif fn_name == "flash_attn_fp8_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", device=device) + elif fn_name == "flash_attn_qkvpacked_fp8_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", packing="qkv", device=device) + elif fn_name == "flash_attn_varlen_fp8_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", device=device) + elif fn_name == "flash_attn_varlen_qkvpacked_fp8_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", packing="qkv", device=device) + else: + valid_fn_names = ", ".join(FUNCTIONS) + raise ValueError(f"{fn_name} should be one of the following functions. {valid_fn_names}") + +def estimate_memory(config): + batch, hq, hk, sq, sk, d_head, causal, dropout = config + memory_estimate = batch * (hq * sq + hk * sk) * d_head * 4 # bytes + return memory_estimate + +def generate_benchmark_configs(is_varlen: bool, packing: Optional[Literal["kv", "qkv"]]): """ - Returns benchmark configurations based on whether variable-length sequences are used. + generates a small number of configs that cover the parameter space well """ - if args.custom_config: - hk = args.hq if not args.hk else args.hk - sk = args.sq if not args.sk else args.sk - return [(args.b, args.hq, hk, args.sq, sk)] - elif varlen: - return [ - (2, 16, 4, 1024, 1024), - (8, 16, 2, 2048, 2048), - (4, 16, 8, 4096, 4096), - (2, 16, 4, 8192, 8192), - (2, 16, 8, 16384, 16384), - (2, 48, 12, 1024, 1024), - (2, 48, 24, 2048, 2048), - (2, 48, 8, 4096, 4096), - (2, 48, 4, 8192, 8192), - (2, 48, 2, 16384, 16384), - (2, 64, 32, 1024, 1024), - (4, 64, 16, 2048, 2048), - (4, 64, 8, 4096, 4096), - (4, 64, 32, 8192, 8192), - (4, 128, 16, 16384, 16384), - ] + + # define all parameter options as lists + batch_sizes = [1, 64] + if packing == "qkv": + hq_values = hk_values = [2, 8] + sq_values = sk_values = [256, 8192] else: - return [ - (16, 16, 16, 1024, 1024), - (8, 16, 16, 2048, 2048), - (4, 16, 16, 4096, 4096), - (1, 8, 8, 8192, 8192), - (1, 2, 2, 16384, 16384), - (2, 48, 48, 1024, 1024), - (2, 48, 48, 2048, 1024), - (1, 8, 8, 4096, 8192), - (1, 8, 8, 8192, 4096), - (2, 4, 4, 16384, 8192), - (2, 8, 8, 1989, 15344), - (4, 16, 16, 4097, 163), - (2, 16, 16, 8122, 2159), - (1, 16, 16, 16281, 7), - (2, 48, 48, 1021, 1020), - (2, 48, 48, 2001, 2048), - (2, 8, 8, 3996, 9639), - (2, 8, 8, 8181, 1021), - ] + if is_varlen: # make sure the seqlen is greater than the batchsize so that subsequences are greater than 0 + hq_values = [16, 32] # test mqa/gqa + hk_values = [8, 16] + sq_values = [128, 512] + sk_values = [512, 2024] + else: + hq_values = [64, 128] # test mqa/gqa + hk_values = [16, 64] + sq_values = [4, 4096] + sk_values = [4096, 16384] # test large k values for inference perf + d_head_values = [64, 128] + causal_values = [True, False] # most models usual causal 
True + dropout_values = [0.0, 0.1] + + # generate all fn_configs without inputs + input_configs = [] + + # one big loop to generate configs + for batch in batch_sizes: + for hq in hq_values: + for hk in hk_values: + for sq in sq_values: + for sk in sk_values: + for d_head in d_head_values: + for causal in causal_values: + for dropout in dropout_values: + # filter configs + input_config = (batch, hq, hk, sq, sk, d_head, causal, dropout) + + # skip if memory usage would be too high + if estimate_memory(input_config) > 8 * 1024 * 1024 * 1024: # 8 GB limit + continue + + # we need hq to be a multiple of hk + if hq % hk != 0: + continue + + # for qkvpacked functions, q and k must have same dimensions + if packing == "qkv" and (sq != sk or hq != hk): + continue + + input_configs.append(input_config) + + return input_configs + +def create_benchmark_fn( + flash_attn, + fn_name, + fn_input, + mode: Literal["fwd", "bwd", "full"] +): + if DEBUG: + print("create_benchmark_fn") + print("flash_attn:", flash_attn) + print("fn_name:", fn_name) + print("fn_input:", len(fn_input)) + print("mode:", mode) + + if fn_name == "flash_attn_func": + q, k, v, do, metadata = fn_input + if mode == "fwd": + def flash_attn_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_func( + q, + k, + v, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out + elif mode == "bwd": + out, lse, S_dmask = flash_attn.flash_attn_func( + q, + k, + v, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_bench_fn(): + dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) + return dq, dk, dv + elif mode == "full": + def flash_attn_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_func( + q, + k, + v, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) + return dq, dk, dv + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_bench_fn -def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal): - flops_per_matmul = 0 - - if fn_name.startswith("prefill"): - if layout == "thd": - q, k, v, input_metadata = varlen_input_helper( - BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device=device) - for i in range(input_metadata.num_contexts): - seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i] - seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i] - flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2 + elif fn_name == "flash_attn_kvpacked_func": + q, kv, do, metadata = fn_input + if mode == "fwd": + def flash_attn_kvpacked_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_kvpacked_func( + q, + kv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out + elif mode == "bwd": + out, lse, S_dmask = flash_attn.flash_attn_kvpacked_func( + q, + kv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def 
flash_attn_kvpacked_bench_fn(): + dq, dkv = torch.autograd.grad(out, (q, kv), do, retain_graph=True) + return dq, dkv + elif mode == "full": + def flash_attn_kvpacked_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_kvpacked_func( + q, + kv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dq, dkv = torch.autograd.grad(out, (q, kv), do, retain_graph=True) + return dq, dkv else: - q, k, v, input_metadata = input_helper( - BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_kvpacked_bench_fn + elif fn_name == "flash_attn_qkvpacked_func": + qkv, do, metadata = fn_input + if mode == "fwd": + def flash_attn_qkvpacked_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_func( + qkv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out + elif mode == "bwd": + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_func( + qkv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_qkvpacked_bench_fn(): + dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) + return dqkv + elif mode == "full": + def flash_attn_qkvpacked_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_func( + qkv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) + return dqkv + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_qkvpacked_bench_fn + elif fn_name == "flash_attn_varlen_func": + q_unpad, k_unpad, v_unpad, do_unpad, metadata = fn_input + if mode == "fwd": + def flash_attn_varlen_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out_unpad + elif mode == "bwd": + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_varlen_bench_fn(): + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) + return dq_unpad, dk_unpad, dv_unpad + elif mode == "full": + def flash_attn_varlen_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, 
k_unpad, v_unpad), do_unpad, retain_graph=True) + return dq_unpad, dk_unpad, dv_unpad + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_varlen_bench_fn + elif fn_name == "flash_attn_varlen_kvpacked_func": + q_unpad, kv_unpad, do_unpad, metadata = fn_input + if mode == "fwd": + def flash_attn_varlen_kvpacked_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_kvpacked_func( + q_unpad, + kv_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out_unpad + elif mode == "bwd": + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_kvpacked_func( + q_unpad, + kv_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, ) - flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD - - if causal: - input_metadata.need_causal() - - o = torch.empty_like(q) - input_data = (q, k, v, o, input_metadata) - elif fn_name.startswith("decode"): - q = torch.randn( - [BATCH, N_CTX_Q, HK, HQ // HK, D_HEAD], - device=device, - dtype=dtype, - requires_grad=False, - ) - k = torch.randn( - [BATCH, N_CTX_K, HK, 1, D_HEAD], - device=device, - dtype=dtype, - requires_grad=False, - ).expand(-1, -1, -1, HQ // HK, -1) - v = torch.randn( - [BATCH, N_CTX_K, HK, 1, D_HEAD], - device=device, - dtype=dtype, - requires_grad=False, - ).expand(-1, -1, -1, HQ // HK, -1) - input_metadata = MetaData(sm_scale=1.3) - input_metadata.layout = "bsghd" - - # Adjust flops calculation if needed - flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD + def flash_attn_varlen_kvpacked_bench_fn(): + dq_unpad, dkv_unpad = torch.autograd.grad(out_unpad, (q_unpad, kv_unpad), do_unpad, retain_graph=True) + return dq_unpad, dkv_unpad + elif mode == "full": + def flash_attn_varlen_kvpacked_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_kvpacked_func( + q_unpad, + kv_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dq_unpad, dkv_unpad = torch.autograd.grad(out_unpad, (q_unpad, kv_unpad), do_unpad, retain_graph=True) + return dq_unpad, dkv_unpad + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_varlen_kvpacked_bench_fn + elif fn_name == "flash_attn_varlen_qkvpacked_func": + qkv_unpad, do_unpad, metadata = fn_input + if mode == "fwd": + def flash_attn_varlen_qkvpacked_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv_unpad, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out_unpad + elif mode == "bwd": + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv_unpad, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + 
deterministic=False, + return_attn_probs=True, + ) + def flash_attn_varlen_qkvpacked_bench_fn(): + dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) + return dqkv_unpad + elif mode == "full": + def flash_attn_varlen_qkvpacked_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv_unpad, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) + return dqkv_unpad + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_varlen_qkvpacked_bench_fn + elif fn_name == "flash_attn_with_kvcache": + q, k_cache, v_cache, _, metadata = fn_input + if mode == "fwd": + def flash_attn_with_kvcache_bench_fn(): + out = flash_attn.flash_attn_with_kvcache( + q, + k_cache, + v_cache, + None, + None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens=None, + cache_batch_idx=None, + cache_leftpad=None, + block_table=None, + causal=metadata.causal, + window_size=(-1, -1), + rotary_interleaved=False, + alibi_slopes=None, + num_splits=0, + ) + return out + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_with_kvcache_bench_fn + elif fn_name == "flash_attn_fp8_func": + (q, descale_q), (k, descale_k), (v, descale_v), (do, descale_do), metadata = fn_input + if mode == "fwd": + def flash_attn_f8_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_fp8_func( + q, + k, + v, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out + elif mode == "bwd": + out, lse, S_dmask = flash_attn.flash_attn_fp8_func( + q, + k, + v, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_f8_bench_fn(): + dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) + return dq, dk, dv + elif mode == "full": + def flash_attn_f8_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_fp8_func( + q, + k, + v, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) + return dq, dk, dv + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") - input_data = (q, k, v, input_metadata) + return flash_attn_f8_bench_fn + elif fn_name == "flash_attn_qkvpacked_fp8_func": + qkv, do, metadata = fn_input + if mode == "fwd": + def flash_attn_qkvpacked_fp8_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( + qkv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out + elif mode == "bwd": + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( + qkv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_qkvpacked_fp8_bench_fn(): + dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) + return dqkv + elif mode == "full": + def 
flash_attn_qkvpacked_fp8_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( + qkv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) + return dqkv + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_qkvpacked_fp8_bench_fn + elif fn_name == "flash_attn_varlen_fp8_func": + (q_unpad, descale_q), (k_unpad, descale_k), (v_unpad, descale_v), (do_unpad, descale_do), metadata = fn_input + if mode == "fwd": + def flash_attn_varlen_fp8_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( + q_unpad, + k_unpad, + v_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out_unpad + elif mode == "bwd": + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( + q_unpad, + k_unpad, + v_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_varlen_fp8_bench_fn(): + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) + return dq_unpad, dk_unpad, dv_unpad + elif mode == "full": + def flash_attn_varlen_fp8_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( + q_unpad, + k_unpad, + v_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) + return dq_unpad, dk_unpad, dv_unpad + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_varlen_fp8_bench_fn + elif fn_name == "flash_attn_varlen_qkvpacked_fp8_func": + qkv_unpad, do_unpad, metadata = fn_input + if mode == "fwd": + def flash_attn_varlen_qkvpacked_fp8_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( + qkv_unpad, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out_unpad + elif mode == "bwd": + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( + qkv_unpad, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_varlen_qkvpacked_fp8_bench_fn(): + dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) + return dqkv_unpad + elif mode == "full": + def flash_attn_varlen_qkvpacked_fp8_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( + qkv_unpad, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + 
metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) + return dqkv_unpad + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_varlen_qkvpacked_fp8_bench_fn else: - raise ValueError("Unsupported benchmark function") - return input_data, flops_per_matmul + valid_fn_names = ", ".join(FUNCTIONS) + raise ValueError(f"{fn_name} should be one of the following functions. {valid_fn_names}") -def run_benchmark(args, fn_name, fn, mode): +def get_packing_type(fn_name: str) -> Optional[Literal["kv", "qkv"]]: + if "_kvpacked" in fn_name: + packing = "kv" + elif "_qkvpacked" in fn_name: + packing = "qkv" + else: + packing = None + + return packing + +def load_flash_attn_module(backend: Literal["triton", "ck"], env_configs: Dict = {}, verbose = False): """ - Runs the benchmark for the provided function based on the provided arguments. + Load the flash_attn module with the specified backend configuration """ - print(f"Benchmarking {fn_name} in {mode} mode...") - dtype = ARGS_TO_TORCH_DTYPE[args.dtype] - head_size = args.d if args.d else 128 - causal = args.causal - varlen = args.layout == "thd" - return_tflops = args.return_tflops - line_names = "TFLOPS" if return_tflops else "Time (ms)" + # remove any existing env variables first + for key in ENV_FLAGS: + if key in os.environ: + del os.environ[key] - # Determine configurations - x_vals_list = get_benchmark_configs(args, varlen=varlen) + # set environment variable for the desired backend + if backend == "triton": + os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE" + os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "0" + os.environ["FLASH_ATTENTION_TRITON_AMD_DEBUG"] = "0" + elif backend == "ck": + os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "FALSE" + else: + raise ValueError(f"Unknown backend {backend}") + + # add custom env configs + add_env_configs(env_configs) + + if verbose: + print(f"Loading flash_attn module with {backend} backend.") + + # Remove any existing flash_attn modules from sys.modules + for module_name in list(sys.modules.keys()): + if module_name.startswith('flash_attn'): + del sys.modules[module_name] + + # Clear CUDA cache + torch.cuda.empty_cache() + + # Import and return the module + import flash_attn + + return flash_attn + +def add_env_configs(env_config: Dict): + for env_key, env_value in env_config.items(): + if env_key in os.environ: + del os.environ[env_key] # remove previous version so that env key is the latest key added + os.environ[env_key] = env_value + +def run_benchmark(func_config: FunctionConfig, input_configs): + """ + Runs the benchmark for the provided function configuration with the given input configurations. 
+ """ + # print new line to seperate benchmark runs + print() + if DEBUG: + print("func_config:", func_config) + + # extract function configuration parameters + fn_name = func_config.fn_name + mode = func_config.mode + dtype = func_config.dtype + backend = func_config.backend + + # load flash attention module + flash_attn_module = load_flash_attn_module(backend, func_config.env_configs, verbose=True) + + # start timing the benchmark + start_time = time.time() + + # print bench fn + print(f"Benchmarking {func_config} ...") # Setup benchmark configurations - configs = [ + bench_configs = [ triton.testing.Benchmark( - x_names=["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K"], - x_vals=x_vals_list, + x_names=["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"], + x_vals=list(input_configs.keys()), line_arg="provider", line_vals=["triton"], - line_names=[line_names], + line_names=["Time (ms)"], styles=[("red", "-")], ylabel="ms", - plot_name=f"benchmark-{fn_name}-d{head_size}-layout{args.layout}-mode{mode}", + plot_name=f"benchmark-{func_config}", args={ - "D_HEAD": head_size, - "dtype": dtype, - "causal": causal, - "mode": mode, }, ) ] - @triton.testing.perf_report(configs) + @triton.testing.perf_report(bench_configs) def bench_function( - BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda" + BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT, provider, device="cuda" ): - warmup = 25 - rep = 100 - flops_per_matmul = 0 + if DEBUG: + print("BATCH:", BATCH) + print("HQ:", HQ) + print("HK:", HK) + print("N_CTX_Q:", N_CTX_Q) + print("N_CTX_Q:", N_CTX_Q) + print("D_HEAD:", D_HEAD) + print("CAUSAL:", CAUSAL) + print("DROPOUT:", DROPOUT) + print("mode:", mode) + print("provider:", provider) + print("device:", device) + fn_input = input_configs[(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT)] + benchmark_fn = create_benchmark_fn(flash_attn_module, fn_name, fn_input, mode) - # generate function inputs - fn_inputs, flops_per_matmul = gen_fn_inputs( - fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, args.layout, causal - ) + # run the benchmark + ms = triton.testing.do_bench(benchmark_fn, warmup=25, rep=100) + return ms - # define the function to benchmark - if mode == "fwd": - benchmark_fn = lambda: fn(*fn_inputs) - total_flops = 2 * flops_per_matmul - elif mode == "bwd": - outputs = fn(*fn_inputs) - output = outputs[0] - grad_output = torch.randn_like(output) - benchmark_fn = lambda: output.backward(grad_output, retain_graph=True) - total_flops = 2 * flops_per_matmul * 2.5 - else: - raise ValueError("Unsupported mode. Choose 'fwd' or 'bwd'.") + df = bench_function.run(save_path=".", print_data=True, return_df=True)[0] + + # set the column name to reflect the function configuration + df = df.rename(columns={"Time (ms)": func_config.column_name()}) + + # calculate and print elapsed time + elapsed_time = time.time() - start_time + print(f"Total time for benchmarking {fn_name} in {mode} mode with {dtype}: {elapsed_time:.2f} seconds") - if causal: - total_flops *= 0.5 + return df - # Run the benchmark - ms = triton.testing.do_bench(benchmark_fn, warmup=warmup, rep=rep) +def filter_modes(requested_modes, fn_name, supported_modes_for_fn): + modes_to_run = [] + if requested_modes: + for mode in requested_modes: + if mode in supported_modes_for_fn: + modes_to_run.append(mode) + else: + warning(f"Mode '{mode}' requested but not supported by function '{fn_name}'. 
Skipping this mode for this function.") + else: + modes_to_run = ["full" if "full" in supported_modes_for_fn else "fwd"] + return modes_to_run - if return_tflops: - return total_flops / ms * 1e-9 - else: - return ms +def get_env_value_combinations(current_backend: Optional[Literal["triton", "ck"]]) -> List[Dict[str, str]]: + # filter environment variations applicable to the current backend + applicable_variations = [ + var_config for var_config in ENV_VARIABLE_CONFIGS + if var_config.backend is None or var_config.backend == current_backend + ] - bench_function.run(save_path=".", print_data=True) + if not applicable_variations: + # no applicable variations, return list with empty dict + return [{}] -def supported_layouts(): - """ - Returns a string describing the supported layouts. - """ - return ( - "bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]\n" - "bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]\n" - "thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]\n" - 'This layout is sometimes called "varlen" or "grouped" layout.' - ) + # prepare keys and value lists + variation_keys = [v.key for v in applicable_variations] + variation_value_lists = [v.values for v in applicable_variations] + + # generate all combinations as dictionaries directly + env_configs = [] + for value_combination in itertools.product(*variation_value_lists): + env_configs.append(dict(zip(variation_keys, value_combination))) + + return env_configs + +def get_input_config_set(config_type): + if config_type == "llama": + # batch, hq, hk, sq, sk, d_head, causal, dropout + input_configs = [ + # LLaMA 3 8B + (1, 32, 8, 8192, 8192, 128, True, 0.0), + # LLaMA 3 70B + (1, 64, 8, 8192, 8192, 128, True, 0.0), + ] + else: + raise ValueError(f"Unknown input config: {config_type}") + + return input_configs -def parse_args(): + +def process_args(): """ - Parses command-line arguments. + Parses command-line arguments and returns function configs and input configs. """ + # create parser parser = argparse.ArgumentParser( prog="Benchmark FlashAttention", allow_abbrev=False, ) - parser.add_argument("-b", type=int, default=0) - parser.add_argument("-hq", type=int, default=0) - parser.add_argument("-hk", type=int, default=0) - parser.add_argument("-sq", type=int, default=0) - parser.add_argument("-sk", type=int, default=0) - parser.add_argument( - "-equal_seqlens", - action="store_true", - default=False, - help="If specified, each context within the thd layout has same seqlen as sq and sk", - ) - parser.add_argument("-d", type=int, default=0) - parser.add_argument("-causal", action="store_true", default=False) - parser.add_argument("-dtype", default="fp16") - parser.add_argument("-return_tflops", action="store_true", default=False) - parser.add_argument( - "-layout", - type=str, - default="bhsd", - help=supported_layouts(), - ) + # functions parser.add_argument( "-benchmark_fn", type=str, nargs="*", - choices=FUNCTIONS.keys(), - help="Function(s) to benchmark: prefill, decode, or both", + choices=FUNCTIONS, + required=True, + help=f"Function(s) to benchmark", ) parser.add_argument( - "-mode", + "--mode", type=str, nargs='*', - default=["fwd", "bwd"], - choices=["fwd", "bwd"], - help="Mode(s) to run: 'fwd' for forward pass, 'bwd' for backward pass", + choices=VALID_MODES, + default=None, + help=f"Benchmarking mode(s) to run. 
If omitted, runs all supported modes for each function.", ) - return parser.parse_args() + # config + parser.add_argument("-b", type=int, default=None, help="Batch size") + parser.add_argument("-hq", type=int, default=None, help="Q Number of heads") + parser.add_argument("-hk", type=int, default=None, help="K and V Number of heads") + parser.add_argument("-sq", type=int, default=None, help="Q Sequence Length") + parser.add_argument("-sk", type=int, default=None, help="K and V Sequence Length") + parser.add_argument("-d", type=int, default=None, help="Head Dimension") + parser.add_argument("-causal", action="store_true", default=None, help="Causal") + parser.add_argument("-dropout", type=float, default=None, help="Dropout") + + # parse args + args = parser.parse_args() + + # parse function args + benchmark_fns = args.benchmark_fn + requested_modes = args.mode + + # fenerate function configurations and input configurations separately + all_function_configs = [] + all_input_configs = {} # Maps function config -> input configs + for fn_name in benchmark_fns: + is_varlen, is_fp8, packing, supported_dtypes, supported_backends, supported_modes_for_fn, supported_env_configs, device = get_fn_params(fn_name) + + # Generate or use custom input configurations + if args.b or args.hq or args.hk or args.sq or args.sk or args.d: + assert args.b and args.hq and args.sq and args.d, ( + "if custom config is specified, please provide at least batch, number of Q heads, Q sequence length, and head size." + ) + + batch = args.b + hq = args.hq + hk = args.hk if args.hk is not None else args.hq + sq = args.sq + sk = args.sk if args.sk is not None else args.sq + d_head = args.d + causal = args.causal if args.causal is not None else False + dropout = args.dropout if args.dropout is not None else 0.0 + input_configs = [(batch, hq, hk, sq, sk, d_head, causal, dropout)] + else: + if True: + input_configs = get_input_config_set("llama") + else: + input_configs = generate_benchmark_configs(is_varlen, packing) + + # filter by mode + modes_to_run = filter_modes(requested_modes, fn_name, supported_modes_for_fn) + if not modes_to_run: + warning(f"No valid modes to run for function '{fn_name}' based on request and function support. Skipping this function.") + continue + + # create a function config for each backend and dtype combination + for backend in supported_backends: + for dtype in supported_dtypes: + for mode in modes_to_run: + for env_config in supported_env_configs[backend]: + func_config = FunctionConfig(fn_name, mode, dtype, backend, env_config) + all_function_configs.append(func_config) + + # Generate inputs for this function configuration + fn_inputs = {} + for input_config in input_configs: + fn_inputs[input_config] = generate_fn_inputs(fn_name, *input_config, dtype, device) + + all_input_configs[func_config] = fn_inputs + + return all_function_configs, all_input_configs + +def check_environment_variables(): + for key in ENV_FLAGS: + if key in os.environ: + raise ValueError(f"Running with {key} environment variable is not recommended for the benching script. Use --help to see how to use the benching script.") def main(): """ Main function to run benchmarks. """ - args = parse_args() - - # Validate arguments - assert ( - args.layout == "thd" or not args.equal_seqlens - ), "Equal sequence lengths arg must be used with the thd layout." 
- args.custom_config = False - if args.b or args.hq or args.hk or args.sq or args.sk or args.d: - args.custom_config = True - assert args.b and args.hq and args.sq and args.d, ( - "If custom config is specified, please provide all of batch, " - "number of Q heads, Q sequence length, and head size." - ) - assert args.dtype in ARGS_TO_TORCH_DTYPE, "Only fp16, bf16 and fp32 types currently supported." + # check environment variables + check_environment_variables() - # determine the functions to benchmark - if args.benchmark_fn is None or len(args.benchmark_fn) == 0: - bench_fn_list = FUNCTIONS.keys() - else: - bench_fn_list = args.benchmark_fn - - # benchmark functions - for fn_name in bench_fn_list: - if fn_name not in FUNCTIONS: - raise ValueError(f"Invalid benchmark function specified: {fn_name}") - for mode in args.mode: - if fn_name == "decode" and mode == "bwd": - print(f"Decode kernel doesnot have a backward pass") - continue - run_benchmark(args, fn_name, FUNCTIONS[fn_name], mode) + # start timing the entire benchmarking process + total_start_time = time.time() + + # process args to get function configs and input configs + function_configs, all_input_configs = process_args() + + # Check if we have multiple function configurations + has_multiple_func_configs = len(function_configs) > 1 + combined_df = None + + # run benchmarks for each function configuration + for func_config in function_configs: + # run benchmark with the input configs for this function config + input_configs = all_input_configs[func_config] + df = run_benchmark(func_config, input_configs) + + # Define the columns that represent input configurations + input_config_cols = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"] + + # merge into one final dataframe + if combined_df is None: + combined_df = df + else: + # Ensure we're joining on input configuration columns + combined_df = combined_df.merge(df, on=input_config_cols, how="outer") + + + # print new line to seperate the combined data information from the benchmark specific information + print() + + # print total time for all benchmarks + total_elapsed_time = time.time() - total_start_time + print(f"Total time for all benchmarks: {total_elapsed_time:.2f} seconds") + + # save combined data and make comparisons if we have multiple function configs + if has_multiple_func_configs: + if len(function_configs) == 2: + func1 = function_configs[0] + func2 = function_configs[1] + + # construct column names for the timing results + col1 = func1.column_name() + col2 = func2.column_name() + + # Check if we're comparing triton vs ck (in either order) + is_triton_vs_ck = ( + (func1.backend == "triton" and func2.backend == "ck") or + (func1.backend == "ck" and func2.backend == "triton") + ) + + # For triton vs ck comparisons + if is_triton_vs_ck: + # For triton vs ck comparisons, always make triton the baseline + if func1.backend == "triton" and func2.backend == "ck": + triton_col = col1 + ck_col = col2 + ratio_col = f"ck_to_triton_ratio" + else: + triton_col = col2 + ck_col = col1 + ratio_col = f"ck_to_triton_ratio" + + # Calculate ratio: ck_time / triton_time (values > 1 mean triton is faster) + combined_df[ratio_col] = combined_df[ck_col] / combined_df[triton_col] + + # print explanation + print(f"Comparison Results (triton vs ck):") + print(f"Ratio values: values > 1 mean triton is faster (by that factor), values < 1 mean ck is faster") + elif False: + # For other comparisons, use the standard approach + ratio_col = f"{func1}_to_{func2}_ratio" + + # 
Calculate the ratio + combined_df[ratio_col] = combined_df[col2] / combined_df[col1] + + # print explanation + print(f"Comparison Results ({func1} vs {func2}):") + print(f"Ratio values: values > 1 mean {func1} is faster than {func2} (by that factor), values < 1 mean slower") + + print(f"Combined data:") + print(combined_df) + + # save csv & markdown + combined_filename = f"benchmark_combined" + combined_df.to_csv(f"{combined_filename}.csv", index=False) + with open(f"{combined_filename}.md", 'w') as f: + f.write(combined_df.to_markdown(index=False, floatfmt=".2f")) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill.py b/flash_attn/flash_attn_triton_amd/bwd_prefill.py index 84212235a64..44e2c294b0d 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill.py @@ -1,10 +1,16 @@ +from typing import Literal, Optional import torch import triton import triton.language as tl -from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG, PERF +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, get_strides_from_layout, is_fp8, write_dropout_mask, create_dropout_mask + +# TODO: move this into utils.py so it's shared among kernels +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) +tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) @triton.jit -def _bwd_preprocess_use_o( +def _bwd_preprocess( Out, DO, Delta, @@ -15,16 +21,18 @@ def _bwd_preprocess_use_o( cu_seqlens_k, max_seqlen_q, max_seqlen_k, + DESCALE_do, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, N_CTX_Q: tl.constexpr, Z: tl.constexpr, H: tl.constexpr, - IS_VARLEN: tl.constexpr + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, ): - pid_m = tl.program_id(0) - pid_bh = tl.program_id(1) + pid_bh = tl.program_id(0) + pid_m = tl.program_id(1) # Compute batch and head indices off_z = pid_bh // H @@ -62,11 +70,18 @@ def _bwd_preprocess_use_o( do_ptrs = do_offset + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok # load - o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32) - do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32) + o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0) + do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0) # compute delta - delta = tl.sum(o * do, axis=1) + if IS_FP8: + stride_descale_q_z = H + descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_h) + + # NOTE: do is scaled into the fp8 range and o is in fp8 but should be in the same scale as fp32 + delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) + else: + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) # write-back delta delta_offset = Delta + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam @@ -94,8 +109,9 @@ def _bwd_kernel_one_col_block( dq_offset, dk_offset, dv_offset, - d_offset, l_offset, + delta_offset, + dropout_offset, stride_dq_all, stride_qz, stride_qh, @@ -112,23 +128,30 @@ def _bwd_kernel_one_col_block( stride_deltaz, stride_deltah, stride_deltam, - Z, - H, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, N_CTX_Q, N_CTX_K, - off_h, - off_z, - off_hz, start_n, 
num_block_m, num_block_n, + dropout_p, + philox_seed, + batch_philox_offset, + descale_q, + descale_k, + descale_v, + descale_do, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, CAUSAL: tl.constexpr, + DROPOUT: tl.constexpr, USE_EXP2: tl.constexpr, + GROUP_SIZE: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, ): if CAUSAL: # TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M @@ -154,11 +177,12 @@ def _bwd_kernel_one_col_block( k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk k = tl.load(k_ptrs, mask=kv_mask, other=0.0) - v = tl.load(v_ptrs, mask=kv_mask, other=0.0) + kT = tl.trans(k) + vT = tl.trans(tl.load(v_ptrs, mask=kv_mask, other=0.0)) # loop over rows - for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M): - offs_m = start_m + tl.arange(0, BLOCK_M) + for start_m in range(lo, num_block_m): + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk dq_ptrs = dq_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk do_ptrs = do_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk @@ -173,7 +197,10 @@ def _bwd_kernel_one_col_block( # recompute p = softmax(qk, dim=-1).T qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, tl.trans(k)) + if IS_FP8: + qk += (tl.dot(q, kT) * descale_q * descale_k) + else: + qk += tl.dot(q, kT) if CAUSAL: col_offset = N_CTX_Q - N_CTX_K @@ -197,27 +224,89 @@ def _bwd_kernel_one_col_block( p_mask = mask_m[:, None] & mask_n[None, :] p = tl.where(p_mask, p, 0.0) - # compute dv - dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + if DROPOUT: + # NOTE: must create a new var p_drop to prevent p (which is used later to compute ds) from changing + philox_offset = batch_philox_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn + # print("philox_seed:", philox_seed) + # print("philox_offset:", philox_offset) + if tl_DROPOUT_USE_PYTORCH: + dropout_ptrs = dropout_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn + dropout_mask = tl.load(dropout_ptrs, mask=p_mask) + else: + rand_vals = tl.rand(philox_seed, philox_offset) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1/ (1 - dropout_p) + + if tl_DROPOUT_DUMP: + dropout_ptrs = dropout_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn + tl.store(dropout_ptrs, dropout_mask, mask=p_mask) + + # apply dropout mask + p_drop = tl.where(dropout_mask, p, 0.0) + p_drop_scaled = p_drop * dropout_scale + + # compute dv + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(p_drop_scaled, FP8_MAX) + dv += (tl.dot(tl.trans(p_drop_scaled * scale_p_dropout).to(do.type.element_ty), do) * descale_p_dropout * descale_do) + else: + dv += tl.dot(tl.trans(p_drop_scaled).to(do.type.element_ty), do) + + # compute dp + if IS_FP8: + dp_drop_scaled = (tl.dot(do, vT) * descale_do * descale_v) + else: + dp_drop_scaled = tl.dot(do, vT) + dp = tl.where(dropout_mask, dp_drop_scaled, 0.0) * dropout_scale + else: + + # compute dv + if IS_FP8: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + dv += (tl.dot(tl.trans(p * scale_p).to(do.type.element_ty), do) * descale_p * descale_do) + else: + dv += tl.dot(tl.trans(p).to(do.type.element_ty), do) - # compute dp - dp = 
tl.dot(do, tl.trans(v)) + # compute dp + if IS_FP8: + dp = (tl.dot(do, vT) * descale_do * descale_v) + else: + dp = tl.dot(do, vT) - # compute ds , ds = p * (dp - delta[:, None]) - d_ptrs = d_offset + offs_m * stride_deltam - Di = tl.load(d_ptrs, mask=mask_m) - ds = (p * (dp - Di[:, None])) * sm_scale - ds = tl.where(p_mask, ds, 0.0).to(Q.dtype.element_ty) - # compute dk = dot(ds.T, q) - dk += tl.dot(tl.trans(ds), q) + # load delta + delta_ptrs = delta_offset + offs_m * stride_deltam + delta_i = tl.load(delta_ptrs, mask=mask_m) + + # compute ds + dscores_scaled = (p * (dp - delta_i[:, None])) + ds = dscores_scaled * sm_scale + ds = tl.where(p_mask, ds, 0.0) + + # compute descale_ds + if IS_FP8: + scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) + else: + scale_ds, descale_ds = 1.0, 1.0 + + # compute dk + if IS_FP8: + dk += (tl.dot(tl.trans(ds * scale_ds).to(q.type.element_ty), q) * descale_ds * descale_q) + else: + dk += tl.dot(tl.trans(ds).to(q.type.element_ty), q) # compute dq if SEQUENCE_PARALLEL: - dq = tl.dot(ds, k) + if IS_FP8: + dq = (tl.dot((ds * scale_ds).to(k.type.element_ty), k) * descale_ds * descale_k) + else: + dq = tl.dot(ds.to(k.type.element_ty), k) else: dq = tl.load(dq_ptrs, mask=q_mask, other=0.0) - dq += tl.dot(ds, k) + if IS_FP8: + dq += (tl.dot((ds * scale_ds).to(k.type.element_ty), k) * descale_ds * descale_k) + else: + dq += tl.dot(ds.to(k.type.element_ty), k) tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask) # write-back dv and dk @@ -225,8 +314,13 @@ def _bwd_kernel_one_col_block( dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk # write-back - tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) - tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) + if GROUP_SIZE != 1: + # use atomic_add to properly accumulate gradients from multiple query heads + tl.atomic_add(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) + tl.atomic_add(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) + else: + tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) + tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) @triton.jit def _bwd_kernel( @@ -240,7 +334,12 @@ def _bwd_kernel( DK, DV, L, - D, + Delta, + Dropout_mask, + DESCALE_q, + DESCALE_k, + DESCALE_v, + DESCALE_do, stride_dq_all, stride_qz, stride_qh, @@ -257,29 +356,44 @@ def _bwd_kernel( stride_deltaz, stride_deltah, stride_deltam, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, Z, - H, + HQ, + HK, num_block_m, num_block_n, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset_base, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, CAUSAL: tl.constexpr, + DROPOUT: tl.constexpr, USE_EXP2: tl.constexpr, IS_VARLEN: tl.constexpr, + GROUP_SIZE: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, ): # program ids - off_hz = tl.program_id(0) + off_zh = tl.program_id(0) if SEQUENCE_PARALLEL: start_n = tl.program_id(1) - off_z = off_hz // H - off_h = off_hz % H + off_z = off_zh // HQ + off_hq = off_zh % HQ + + # check if GQA/MQA + if GROUP_SIZE != 1: + off_hk = off_hq // GROUP_SIZE + else: + off_hk = off_hq if IS_VARLEN: # Compute sequence lengths for the current batch @@ -296,23 +410,40 @@ def _bwd_kernel( k_start = 0 N_CTX_Q = max_seqlen_q N_CTX_K = max_seqlen_k - # input tensor offsets - q_offset = Q + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm - k_offset = K + off_z * 
stride_kz + off_h * stride_kh + k_start * stride_kn - v_offset = V + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn - do_offset = DO + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm - l_offset = L + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam - d_offset = D + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam + q_offset = Q + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm + k_offset = K + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn + v_offset = V + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn + do_offset = DO + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm + l_offset = L + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam + delta_offset = Delta + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam + + if DROPOUT: + batch_philox_offset = philox_offset_base + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm + dropout_offset = Dropout_mask + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm + else: + batch_philox_offset = 0 + dropout_offset = 0 + + if IS_FP8: + stride_descale_q_z = HQ + stride_descale_kv_z = HK + descale_q = tl.load(DESCALE_q + off_z * stride_descale_q_z + off_hq) + descale_k = tl.load(DESCALE_k + off_z * stride_descale_kv_z + off_hk) + descale_v = tl.load(DESCALE_v + off_z * stride_descale_kv_z + off_hk) + descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_hq) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + # output tensor offsets - dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn - dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn + dk_offset = DK + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn + dv_offset = DV + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn if SEQUENCE_PARALLEL: - dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm + dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm else: - dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm + dq_offset = DQ + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm # inner loop if SEQUENCE_PARALLEL: @@ -327,7 +458,7 @@ def _bwd_kernel( DK, DV, L, - D, + Delta, q_offset, k_offset, v_offset, @@ -335,8 +466,9 @@ def _bwd_kernel( dq_offset, dk_offset, dv_offset, - d_offset, l_offset, + delta_offset, + dropout_offset, stride_dq_all, stride_qz, stride_qh, @@ -350,26 +482,33 @@ def _bwd_kernel( stride_vh, stride_vn, stride_vk, - stride_deltaz, - stride_deltah, + stride_deltaz, + stride_deltah, stride_deltam, - Z, - H, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, N_CTX_Q, N_CTX_K, - off_h, - off_z, - off_hz, start_n, num_block_m, num_block_n, + dropout_p, + philox_seed, + batch_philox_offset, + descale_q, + descale_k, + descale_v, + descale_do, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, BLOCK_N=BLOCK_N, SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, CAUSAL=CAUSAL, + DROPOUT=DROPOUT, USE_EXP2=USE_EXP2, + GROUP_SIZE=GROUP_SIZE, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX ) else: for start_n in range(0, num_block_n): @@ -384,7 +523,7 @@ def _bwd_kernel( DK, DV, L, - D, + Delta, q_offset, k_offset, v_offset, @@ -392,8 +531,9 @@ def _bwd_kernel( dq_offset, dk_offset, dv_offset, - d_offset, 
l_offset, + delta_offset, + dropout_offset, stride_dq_all, stride_qz, stride_qh, @@ -407,54 +547,69 @@ def _bwd_kernel( stride_vh, stride_vn, stride_vk, - stride_deltaz, - stride_deltah, + stride_deltaz, + stride_deltah, stride_deltam, - Z, - H, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, N_CTX_Q, N_CTX_K, - off_h, - off_z, - off_hz, start_n, num_block_m, num_block_n, + dropout_p, + philox_seed, + batch_philox_offset, + descale_q, + descale_k, + descale_v, + descale_do, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, BLOCK_N=BLOCK_N, SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, CAUSAL=CAUSAL, + DROPOUT=DROPOUT, USE_EXP2=USE_EXP2, + GROUP_SIZE=GROUP_SIZE, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX ) -# NOTE: smaller blocks have lower accuracy. more accumlation error probably 128 * 128 seems good but leads to oom. 64 * 64 has accumlation errors but no oom. +# NOTE: smaller blocks have lower accuracy. more accumulation error probably 128 * 128 seems good but leads to oom. 64 * 64 has accumulation errors but no oom. def attention_prefill_backward_triton_impl( - do, - q, - k, - v, - o, - softmax_lse, - dq, - dk, - dv, + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, sm_scale: float, - alibi_slopes, - causal, - layout: str, - cu_seqlens_q, - cu_seqlens_k, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], max_seqlen_q: int, max_seqlen_k: int, + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], use_exp2: bool, - sequence_parallel = True, + sequence_parallel: bool = True, + # fp8 + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None, ): if DEBUG: print() - print("attention_prefill_backward_triton_new_impl") + print("attention_prefill_backward_triton_impl") print("do:", do, do.shape) print("q:", q, q.shape) print("k:", k, k.shape) @@ -472,24 +627,38 @@ def attention_prefill_backward_triton_impl( print("cu_seqlens_k:", cu_seqlens_k) print("max_seqlen_q:", max_seqlen_q) print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:", philox_offset) print("use_exp2:", use_exp2) print("sequence_parallel:", sequence_parallel) + print("descale_q:", descale_q) + print("descale_k:", descale_k) + print("descale_v:", descale_v) + print("descale_do:", descale_do) + + IS_FP8 = is_fp8(q) + if IS_FP8: + FP8_MAX=torch.finfo(q.dtype).max + else: + FP8_MAX=None - # make contigious + # make contiguous q = q.contiguous() k = k.contiguous() v = v.contiguous() softmax_lse = softmax_lse.contiguous() # get strides and shape - batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) + batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shapes_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) stride_qz, stride_qh, stride_qm, stride_qk = q_strides stride_kz, stride_kh, stride_kn, stride_kk = k_strides stride_vz, stride_vh, stride_vn, stride_vk = v_strides stride_oz, 
stride_oh, stride_om, stride_ok = o_strides - batch_headsize = batch * nheads_q is_varlen = layout == "thd" + group_size = nheads_q // nheads_k + use_dropout = (dropout_p > 0.0) # FIXME: some configs lead to oom for some reason when using 64 x 64 blocks if max_seqlen_q <= 32 or max_seqlen_k <= 32: @@ -498,7 +667,12 @@ def attention_prefill_backward_triton_impl( else: BLOCK_M = 64 BLOCK_N = 64 - num_warps = 4 # NOTE: originial is 8. changing it to 1 caused issues be careful + + if DEBUG: + print("BLOCK_M:", BLOCK_M) + print("BLOCK_N:", BLOCK_N) + + num_warps = 4 # NOTE: original is 8. changing it to 1 caused issues be careful num_stages = 1 waves_per_eu = 1 @@ -513,48 +687,13 @@ def attention_prefill_backward_triton_impl( ACTUAL_BLOCK_DMODEL = head_size do = do.contiguous() - # NOTE: we might need to copy the output tensor if they are not continuous or have other issues - copy_back = {"dq": False, "dk": False, "dv": False} # deal with dq - if dq is None: - if sequence_parallel: - dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype) - else: - dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype) - else: - dq_og = dq - if (not dq.is_contiguous()): - dq = dq.contiguous() - copy_back["dq"] = True - - if sequence_parallel: - dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype) - copy_back["dq"] = True - else: - # NOTE: the kernel does inplace accumlation so dq has to be zeros. This avoids the case where we are passed empty dq and it is not all zeros - dq.zero_() + if sequence_parallel: + dq = dq.unsqueeze(0).repeat(num_blocks_n, *([1] * len(q.shape))) # we do repeat instead of expand because we need to write data so views are not enough stride_dq_all = dq.stride()[0] - # deal with dk, dv - if (dk is None) or (dv is None): - dk = torch.empty_like(k) - dv = torch.empty_like(v) - else: - if (not dk.is_contiguous()): - dk_og = dk - dk = dk.contiguous() - copy_back["dk"] = True - - if (not dv.is_contiguous()): - dv_og = dv - dv = dv.contiguous() - copy_back["dv"] = True - - if DEBUG: - print("copy_back:", copy_back) - - # assert contigious + # assert contiguous assert do.is_contiguous() assert q.is_contiguous() assert k.is_contiguous() @@ -563,66 +702,53 @@ def attention_prefill_backward_triton_impl( assert softmax_lse.is_contiguous() # init delta - delta = torch.empty_like(softmax_lse) + delta = torch.zeros_like(softmax_lse) if is_varlen: stride_deltam, stride_deltah = delta.stride() stride_deltaz = 0 else: stride_deltaz, stride_deltah, stride_deltam = delta.stride() - _bwd_preprocess_use_o[(num_blocks_m, batch_headsize)]( + # dropout mask tensor for debugging. 
We dump the dropout mask created in the kernel for testing + if use_dropout: + if DROPOUT_USE_PYTORCH: + dropout_mask = create_dropout_mask(dropout_p, (batch, nheads_q, max_seqlen_q, max_seqlen_k), seed = philox_seed) + else: + dropout_mask = torch.zeros((batch, nheads_q, max_seqlen_q, max_seqlen_k), device=q.device, + dtype=torch.float32) + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn = (dropout_mask.stride(0), dropout_mask.stride(1), dropout_mask.stride(2), dropout_mask.stride(3)) + else: + dropout_mask = None + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn = (0, 0 , 0 , 0) + + + _bwd_preprocess[(batch * nheads_q, num_blocks_m)]( o, do, delta, stride_oz, stride_oh, stride_om, stride_ok, - stride_oz, stride_oh, stride_om, stride_ok, + stride_oz, stride_oh, stride_om, stride_ok, # FIXME: don't share strides with derivatives this was causing a lot of issues stride_deltaz, stride_deltah, stride_deltam, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + descale_do, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, N_CTX_Q=max_seqlen_q, Z=batch, H=nheads_q, - IS_VARLEN=is_varlen + IS_VARLEN=is_varlen, + IS_FP8=IS_FP8 ) if DEBUG: - print("_bwd_kernel inputs") - print("do:", do, do.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("sm_scale", sm_scale) - print("o:", o, o.shape) - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) - print("L:", softmax_lse, softmax_lse.shape) print("delta:", delta, delta.shape) - print("stride_qz, stride_qh, stride_qm, stride_qk:", stride_qz, stride_qh, stride_qm, stride_qk) - print("stride_kz, stride_kh, stride_kn, stride_kk:", stride_kz, stride_kh, stride_kn, stride_kk) - print("stride_vz, stride_vh, stride_vn, stride_vk:", stride_vz, stride_vh, stride_vn, stride_vk) - print("batch_q:", batch) - print("heads_q:",nheads_q) - print("max_seqlen_q:",max_seqlen_q) - print("max_seqlen_k:",max_seqlen_k) - print("BLOCK_M:",BLOCK_M) - print("BLOCK_N:",BLOCK_M) - print("BLOCK_DMODEL:",BLOCK_DMODEL) - print("ACTUAL_BLOCK_DMODEL:",ACTUAL_BLOCK_DMODEL) - print("SEQUENCE_PARALLEL:",sequence_parallel) - print("CAUSAL:",causal) - print("num_warps:",num_warps) - print("num_stages:", num_stages) - print("USE_EXP2:", use_exp2) - print("num_blocks_m:", num_blocks_m) - print("num_blocks_n:", num_blocks_n) - - _bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)]( + print("group_size:", group_size) + + _bwd_kernel[(batch * nheads_q, num_blocks_n if sequence_parallel else 1)]( q, k, v, @@ -634,58 +760,55 @@ def attention_prefill_backward_triton_impl( dv, softmax_lse, delta, + dropout_mask, + descale_q, + descale_k, + descale_v, + descale_do, stride_dq_all, - stride_qz, stride_qh, stride_qm, stride_qk, + stride_qz, stride_qh, stride_qm, stride_qk, # FIXME: don't share strides with derivatives this was causing a lot of issues stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vn, stride_vk, stride_deltaz, stride_deltah, stride_deltam, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, batch, nheads_q, + nheads_k, num_blocks_m, num_blocks_n, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, philox_seed, philox_offset, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=BLOCK_DMODEL, ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, SEQUENCE_PARALLEL=sequence_parallel, CAUSAL=causal, + DROPOUT=use_dropout, USE_EXP2=use_exp2, num_warps=num_warps, 
num_stages=num_stages, waves_per_eu = waves_per_eu, - IS_VARLEN=is_varlen + IS_VARLEN=is_varlen, + GROUP_SIZE=group_size, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX ) - if DEBUG: - print("_bwd_kernel outputs") - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) - print("delta:", delta, delta.shape) - if sequence_parallel: dq = dq.sum(dim=0) if DEBUG: - print("attention_prefill_backward_triton_new_impl outputs") - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) + print("attention_prefill_backward_triton_impl outputs") print("dv:", dv, dv.shape) - print("delta:", delta, delta.shape) - print("copy_back:", copy_back) - - if copy_back["dq"]: - dq_og.copy_(dq) - dq = dq_og - if copy_back["dk"]: - dk_og.copy_(dk) - dk = dk_og - if copy_back["dv"]: - dv_og.copy_(dv) - dv = dv_og - - return dq, dk, dv, delta, None, None + print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) + if use_dropout: + print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None) + print("dropout_fraction bwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item()) + write_dropout_mask(dropout_mask, "dropout_mask_bwd") + + return delta diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py new file mode 100644 index 00000000000..3c018be4fa0 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py @@ -0,0 +1,3266 @@ +import torch +import triton +import triton.language as tl + +from typing import Optional, Tuple + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + +@triton.jit +def compute_fp8_scaling_factors(x, fp8_max: tl.constexpr): + # compute fp8 scaling and descaling factor for a block + x_amax = tl.max(tl.abs(x)) # NOTE: abs deals with negative values + x_amax = tl.where(x_amax <= 1e-9, 1e-9, x_amax) + scale_x = fp8_max / x_amax + descale_x = x_amax / fp8_max + return scale_x, descale_x + +def is_fp8(x): + if x.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}: + if arch_supports_fp8(): + return True + else: + raise RuntimeError("This device does not support fp8") + else: + return False + + +def cast_to_fp8( + x: torch.Tensor, + fp8_dtype, + layout, + clamp_val=1e-9, +): + if len(x.shape) != 4: + raise ValueError(f"'bshd' tensor should have shape [batch, seqlen, heads, dim], got {x.shape}") + reduce_dims = (1, 3) # seq_len and dim dimensions + + # Compute the absolute max along reduce_dims, clamped to avoid 0-scale + x_abs_max = x.abs().amax(dim=reduce_dims) + x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) + + # Unsqueeze back to a shape suitable for broadcast + unsqueeze_dims = sorted(reduce_dims) + for d in unsqueeze_dims: + x_abs_max = x_abs_max.unsqueeze(d) + + # compute scale and descale + fp8_max = torch.finfo(fp8_dtype).max + scale = fp8_max / x_abs_max + descale_factor = x_abs_max / fp8_max + + # cast to FP8, optionally setting requires_grad + x_fp8 = (x * scale).to(fp8_dtype) + + return x_fp8, descale_factor + + +def cast_varlen_to_fp8( + x: torch.Tensor, + fp8_dtype: torch.dtype, + cu_seqlens, + clamp_val: float = 1e-9, +) -> tuple[torch.Tensor, torch.Tensor]: + # validate tensor shape + if len(x.shape) != 3: + raise ValueError(f"tensor should have shape [total_seqlen, heads, dim], got {x.shape}") + num_heads = x.shape[1] + + # Get batch size from cu_seqlens + batch = cu_seqlens.shape[0] - 1 + fp8_max = torch.finfo(fp8_dtype).max + + # Compute scale and descale factors per sequence + 
x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) + descale_factors = torch.zeros((batch, num_heads), device=x.device, dtype=torch.float32) + + for i in range(batch): + start = cu_seqlens[i] + end = cu_seqlens[i + 1] + x_slice = x[start:end] # Slice for current sequence + + # Standard tensor (0: seq_len, 2: head_dim) + x_abs_max = x_slice.abs().amax(dim=(0, 2)) # [heads] + + # apply minimum clamping + x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) + + # compute scale and descale factors + scale_i = fp8_max / x_abs_max + descale_i = x_abs_max / fp8_max + + # store descale factors + descale_factors[i, :] = descale_i + + scale_reshape = scale_i.reshape(1, num_heads, 1) + + # scale and cast to FP8 + x_fp8[start:end] = (x_slice * scale_reshape).to(fp8_dtype) + + return x_fp8, descale_factors + + +#TODO Move this to a common folder. Will need to add future arch list +def get_arch(): + return triton.runtime.driver.active.get_current_target().arch + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + +def arch_supports_fp8(): + return is_hip() and get_arch() in ('gfx942') + +@triton.jit +def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) + else: + tensor = tl.load(ptrs) + return tensor + +@triton.jit +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + stride_kn, + stride_vk, + stride_sn, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + sd_mask_ptrs, + dropout_mask_ptrs, + philox_seed, + philox_ptrs, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + alibi_slope, + descale_q, + descale_k, + descale_v, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_POW2: tl.constexpr, + SM_SCALE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_SCORES: tl.constexpr, + PADDED_HEAD: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # loop over k, v, and update accumulator + + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL_POW2) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, BLOCK_DMODEL, seqlen_k) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. 
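+ # [illustrative note, assumed values] worked example of the boundary masking just below:
+ # with BLOCK_N = 64 and seqlen_k = 100, n_extra_tokens = 100 % 64 = 36, and the last K/V
+ # block covers columns 64..127 even though only 64..99 are real tokens; the mask
+ # size_n < seqlen_k therefore forces columns >= 100 to -inf before the softmax.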
+ if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + + # compute masks + q_mask = (OFFS_M[:, None] < seqlen_q) + k_mask = ((start_n + tl.arange(0, BLOCK_N))[None, :] < seqlen_k) + p_mask = q_mask & k_mask + + # -- compute qk ---- + if IS_FP8: + qk += (tl.dot(q, k) * descale_q * descale_k) + else: + qk += tl.dot(q, k) + qk_scaled = qk * SM_SCALE + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, global_m_positions, + global_n_positions) + qk_scaled += alibi_block + # get max scores so far + m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) + + # scale and subtract max + q_shifted = qk_scaled - m_ij[:, None] + + # Compute scaled QK and softmax probabilities + p = tl.math.exp2(q_shifted * RCP_LN2) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance + dropout_mask = rng_output > dropout_p + tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask) + + # return scores with negative values for dropped vals + sd_mask = tl.where(dropout_mask, p, -p) + tl.store(sd_mask_ptrs, sd_mask, mask=p_mask) + + # apply dropout mask in place + p = tl.where(dropout_mask, p, 0.0) + elif RETURN_SCORES: + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. 
We are not doing that + tl.store(sd_mask_ptrs, p, mask=p_mask) + + # -- update output accumulator -- + # alpha is an adjustment factor for acc and li as we loop and find new maxes + # store the diff in maxes to adjust acc and li as we discover new maxes + m_diff = m_i - m_ij + alpha = tl.math.exp2(m_diff * RCP_LN2) + acc = acc * alpha[:, None] + v = load_fn(v_ptrs, k_offs_n, k_offs_k, seqlen_k, BLOCK_DMODEL) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + + if IS_FP8: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) + else: + acc += tl.dot(p.to(v.type.element_ty), v) + + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if RETURN_SCORES: + sd_mask_ptrs += BLOCK_N * stride_sn + + if ENABLE_DROPOUT: + dropout_mask_ptrs += BLOCK_N * stride_sn + philox_ptrs += BLOCK_N * stride_sn + + return acc, l_i, m_i + + +@triton.jit +def _attn_fwd(q_ptr: torch.Tensor, + k_ptr: torch.Tensor, + v_ptr: torch.Tensor, + descale_q_ptr: torch.Tensor, + descale_k_ptr: torch.Tensor, + descale_v_ptr: torch.Tensor, + out_ptr: torch.Tensor, + alibi_slopes_ptr: torch.Tensor, + s_dmask_ptr: torch.Tensor, + dropout_mask_ptr: torch.Tensor, + softmax_lse_ptr: torch.Tensor, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, + stride_oz, stride_oh, stride_om, stride_on, + stride_alibi_z, stride_alibi_h, + stride_sd_z, stride_sd_h, stride_sd_m, stride_sd_n, + stride_lse_z, stride_lse_h, stride_lse_m, + sm_scale, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset, + SEQLEN_Q: tl.constexpr, + SEQLEN_K: tl.constexpr, + IS_CAUSAL: tl.constexpr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_POW2: tl.constexpr, + RETURN_SCORES: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + VARLEN: tl.constexpr, +): + #calculate offsets + start_m = tl.program_id(0) #seqlen_q + off_q_head = tl.program_id(1) #num_q_heads + off_z = tl.program_id(2) #batch + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_POW2) + + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = SEQLEN_Q + seqlen_k = SEQLEN_K + + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. 
+ # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. + if n_blocks <= 0: + offs_out = (off_z * stride_oz + + off_q_head * stride_oh + + cu_seqlens_q_start * stride_om + + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=out_ptr.type.element_ty) + out_mask = (offs_m[:, None] < seqlen_q) & (offs_d < BLOCK_DMODEL) + tl.store(out_ptr + offs_out, acc, mask=out_mask) + + if softmax_lse_ptr is not None: + offs_lse = (off_z * stride_lse_z + + off_q_head * stride_lse_h + + cu_seqlens_q_start * stride_lse_m + + offs_m*stride_lse_m + ) + lse_mask = offs_m < SEQLEN_Q + lse = tl.full([BLOCK_M], value=0.0, dtype=tl.float32) + tl.store(softmax_lse_ptr + offs_lse, lse, mask=lse_mask) + # TODO: Should dropout and return encoded softmax be handled here too? + + return + + grp_sz:tl.constexpr = NUM_Q_HEADS // NUM_K_HEADS + if grp_sz != 1: #Grouped Query Attention + off_k_head = off_q_head // grp_sz + else: + off_k_head = off_q_head + + #q,k,v offsets + q_offs = (off_z * stride_qz + + off_q_head * stride_qh + + cu_seqlens_q_start * stride_qm + + offs_m[:, None] * stride_qm + offs_d[None, :]*stride_qk + ) + q_ptrs = q_ptr + q_offs + + k_offs = (off_z * stride_kz + + off_k_head * stride_kh + + cu_seqlens_k_start * stride_kn + + offs_d[:, None] * stride_kk + offs_n[None, :]*stride_kn + ) + k_ptrs = k_ptr + k_offs + + v_offs = (off_z * stride_vz + + off_k_head * stride_vh + + cu_seqlens_k_start * stride_vn + + offs_n[:, None] * stride_vn + offs_d[None, :]*stride_vk + ) + v_ptrs = v_ptr + v_offs + + #alibi slopes + if alibi_slopes_ptr is not None: + alibi_offs = off_z * stride_alibi_z + off_q_head * stride_alibi_h + alibi_slope = tl.load(alibi_slopes + alibi_offs) + else: + alibi_slope = None + + #s_dmask (return_scores) + if s_dmask_ptr is not None: + s_dmask_offs = (off_z * stride_sd_z + + off_q_head * stride_sd_h + + offs_m[:, None] * stride_sd_m + + offs_n[None, :] * stride_sd_n + ) + s_dmask_ptrs = s_dmask_ptr + s_dmask_offs + else: + s_dmask_ptrs = None + + #dropout + if dropout_mask_ptr is not None: + dropout_mask_offs = (off_z * stride_sd_z + + off_q_head * stride_sd_h + + offs_m[:, None] * stride_sd_m + + offs_n[None, :] * stride_sd_n + ) + dropout_mask_ptrs = dropout_mask_ptr + dropout_mask_offs + philox_ptrs = (philox_offset + + off_z * stride_sd_z + + off_q_head * stride_sd_h + + offs_m[:, None] * stride_sd_m + + offs_n[None, :] * stride_sd_n + ) + else: + dropout_mask_ptrs = None + philox_ptrs = None + + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=tl.float32) + if (BLOCK_DMODEL == BLOCK_DMODEL_POW2): + q_mask = (offs_m[:, None] < seqlen_q) + else: + q_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_mask, other=0.0) + if IS_FP8: + descale_q = tl.load(descale_q_ptr 
+ off_z * stride_descale_q_z + off_q_head) + descale_k = tl.load(descale_k_ptr + off_z * stride_descale_k_z + off_k_head) + descale_v = tl.load(descale_v_ptr + off_z * stride_descale_v_z + off_k_head) + else: + descale_q, descale_k ,descale_v = 1.0, 1.0, 1.0 + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N -seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + + #if CAUSAL, then determine masked_blocks and full blocks + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + stride_kn, + stride_vn, + stride_sd_n, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + s_dmask_ptrs, dropout_mask_ptrs, philox_seed, philox_ptrs, + block_min, block_max, 0, 0, 0, alibi_slope, + descale_q, descale_k, descale_v, + offs_m, offs_n, BLOCK_M, BLOCK_N, BLOCK_DMODEL,BLOCK_DMODEL_POW2, + sm_scale, False, MASK_STEPS=False, ENABLE_DROPOUT=ENABLE_DROPOUT, + RETURN_SCORES=RETURN_SCORES, PADDED_HEAD=BLOCK_DMODEL!=BLOCK_DMODEL_POW2, + IS_FP8=IS_FP8, FP8_MAX=FP8_MAX + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vn + if RETURN_SCORES: + s_dmask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n + if ENABLE_DROPOUT: + dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n + acc, l_i, m_i = _attn_fwd_inner(acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + stride_kn, stride_vn, stride_sd_n, + start_m, seqlen_k, seqlen_q, + dropout_p, + s_dmask_ptrs, dropout_mask_ptrs, philox_seed, philox_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + descale_q, descale_k, descale_v, + offs_m, offs_n, BLOCK_M, BLOCK_N, BLOCK_DMODEL,BLOCK_DMODEL_POW2, + sm_scale, IS_CAUSAL, MASK_STEPS=True, ENABLE_DROPOUT=ENABLE_DROPOUT, + RETURN_SCORES=RETURN_SCORES, PADDED_HEAD=BLOCK_DMODEL!=BLOCK_DMODEL_POW2, + IS_FP8=IS_FP8, FP8_MAX=FP8_MAX + ) + # epilogue + # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + if ENABLE_DROPOUT: + dropout_scale = 1 / (1 - dropout_p) + acc = acc * dropout_scale + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). 
We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL_POW2, ), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + + # write back LSE(Log Sum Exponents), the log of the normalization constant + overflow_size = end_m_idx - seqlen_q + if softmax_lse_ptr is not None: + RCP_LN2: tl.constexpr = 1.4426950408889634 + LN2: tl.constexpr = 0.6931471824645996 + # compute log-sum-exp in base 2 units + mi_base2 = m_i * RCP_LN2 + softmax_lse = mi_base2 + tl.math.log2(l_i) + # convert back to natural units + softmax_lse *= LN2 + + if IS_CAUSAL: + # zero out nans caused by -infs when doing causal + lse_causal_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx + softmax_lse = tl.where(lse_causal_mask, 0.0, softmax_lse) + + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + offs_lse = off_z * stride_lse_z + off_q_head * stride_lse_h + cu_seqlens_q_start * stride_lse_m + offs_m*stride_lse_m + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + lse_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(softmax_lse_ptr + offs_lse, softmax_lse, mask=lse_mask) # the log of the normalization constant + else: + tl.store(softmax_lse_ptr + offs_lse, softmax_lse) # the log of the normalization constant + + # write back O + offs_out = (off_z * stride_oz + + off_q_head * stride_oh + + cu_seqlens_q_start * stride_om + + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) + out_mask = tl.full([BLOCK_M, BLOCK_DMODEL_POW2], 1, dtype=tl.int1) + if overflow_size > 0: + out_mask = out_mask & (offs_m[:, None] < seqlen_q) + if BLOCK_DMODEL != BLOCK_DMODEL_POW2: + out_mask = out_mask & (offs_d[None, :] < BLOCK_DMODEL) + op = acc.to(out_ptr.dtype.element_ty) + tl.store(out_ptr + offs_out, op, mask=out_mask) + +def _flash_attn_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + alibi_slopes: Optional[torch.Tensor], + return_lse: bool, + return_softmax: bool, + max_seqlen_q: int, + max_seqlen_k: int, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + #FP8 + IS_FP8 = is_fp8(q) + FP8_MAX: tl.constexpr=torch.finfo(q.dtype).max + is_varlen = True if cu_seqlens_q is not None else False + + if IS_FP8: + o = torch.zeros_like(q, dtype=torch.float32) + else: + o = torch.zeros_like(q) + if is_varlen: + #Layout for q,k,v is thd ie [total_tokens, num_head, head_dim] + batch, seqlen_q, num_q_heads, head_sz = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] + seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, 
v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + #Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] + batch, seqlen_q, num_q_heads, head_sz = q.shape + seqlen_k = k.shape[1] + num_k_heads = k.shape[2] + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + #padding for head_dim. Power of 2 or 16 + BLOCK_DMODEL_POW2 = triton.next_power_of_2(head_sz) + BLOCK_DMODEL_POW2 = max(BLOCK_DMODEL_POW2, 16) + + #softmax_lse [batch, num_q_heads, seqlen_q] + if return_lse: + if is_varlen: + softmax_lse = torch.zeros((q.shape[0], num_q_heads), device=q.device, dtype=torch.float32) + stride_lse_z, stride_lse_h, stride_lse_m = 0, softmax_lse.stride(1), softmax_lse.stride(0) + else: + softmax_lse = torch.zeros((batch, num_q_heads, max_seqlen_q), device=q.device, dtype=torch.float32) + stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() + else: + softmax_lse = None + + #exp_scores [batch, num_q_heads, seqlen_q, seqlen_k] + enable_dropout = dropout_p > 0.0 + if enable_dropout: + philox_seed = torch.randint(0, 0xffffff, (1,))[0].item() #No specific reason to restrict range to 0xffffff + philox_offset = torch.randint(0, 0xffffff, (1,))[0].item() #Pass in an int, not Tensor + else: + philox_seed = 0 + philox_offset = 0 + if return_softmax or enable_dropout: + s_dmask = torch.zeros((batch, num_q_heads, max_seqlen_q, max_seqlen_k), device=q.device, dtype=torch.float32) + dropout_mask = torch.zeros((batch, num_q_heads, max_seqlen_q, max_seqlen_k), device=q.device, dtype=torch.float32) + else: + s_dmask = None + dropout_mask = None + + + # Best config from ROCm/triton/python/perf-kernels/flash_attention.py::attn_fwd autotuning is BLOCK_M: 128, BLOCK_N: 64, waves_per_eu: 2, num_warps: 4, num_ctas: 1, num_stages: 1 + # Tuned for MI300x + config = { + 'BLOCK_M': 128, + 'BLOCK_N': 32, # BLOCK_N: 64 spills for _attn_fwd + 'waves_per_eu': 2, + 'num_warps': 4, + 'num_ctas': 1, + 'num_stages': 1, + } + + grid = lambda META:(triton.cdiv(seqlen_q, META['BLOCK_M']), num_q_heads, batch) + _attn_fwd[grid](q, + k, + v, + descale_q, + descale_k, + descale_v, + o, + alibi_slopes, + s_dmask, + dropout_mask, + softmax_lse, + *q_strides, + *k_strides, + *v_strides, + descale_q.stride(0) if descale_q is not None else 0, + descale_k.stride(0) if descale_k is not None else 0, + descale_v.stride(0) if descale_v is not None else 0, + *o_strides, + alibi_slopes.stride(0) if alibi_slopes is not None else 0, + alibi_slopes.stride(1) if alibi_slopes is not None else 0, + s_dmask.stride(0) if s_dmask is not None else 0, + s_dmask.stride(1) if s_dmask is not None else 0, + s_dmask.stride(2) if s_dmask is not None else 0, + s_dmask.stride(3) if s_dmask is not None else 0, + stride_lse_z if softmax_lse is not None else 0, + stride_lse_h if softmax_lse is not None else 0, + stride_lse_m if softmax_lse is not None else 0, + softmax_scale, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset, + SEQLEN_Q=max_seqlen_q, + SEQLEN_K=max_seqlen_k, + IS_CAUSAL=causal, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_DMODEL=head_sz, + BLOCK_DMODEL_POW2=BLOCK_DMODEL_POW2, + RETURN_SCORES=return_softmax, + ENABLE_DROPOUT=enable_dropout, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + VARLEN=is_varlen, + **config + ) + + return o, 
softmax_lse, s_dmask, philox_seed, philox_offset + +# This function computes delta given output Out and gradient DO +# Here is the I/O shape: +# Out: (batch, nhead_q, max_seqlens_q, headDim) +# DO: (batch, nhead_q, max_seqlens_q, headDim) +# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at +@triton.jit +def _bwd_preprocess( + o_ptr, do_ptr, # noqa: E741 + delta_ptr, + stride_o_b, stride_o_h, stride_o_m, stride_o_k, + stride_delta_b, stride_delta_h, stride_delta_m, + stride_descale_do_z, + cu_seqlens_q, max_seqlen_q, + descale_do_ptr, + BLOCK_M: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr +): + pid_m = tl.program_id(0) #seqlen + bid = tl.program_id(1) #batch + hid = tl.program_id(2) #head + + # Handle varlen + q_start = 0 + seqlen_q = max_seqlen_q + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + seqlen_q = q_end - q_start + else: + q_start = 0 + seqlen_q = max_seqlen_q + + # Compute offsets + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # Offset O/DO by batch, head and q_start + offs = (bid * stride_o_b + + hid * stride_o_h + + q_start * stride_o_m + offs_m[:, None] * stride_o_m + + offs_k[None, :] * stride_o_k) + + # create masks + mask_m = offs_m < seqlen_q + mask = mask_m[:, None] + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + mask &= offs_k[None, :] < BLOCK_D_MODEL + + # load [BLOCK_M, BLOCK_D_MODEL_POW2] + o = tl.load(o_ptr + offs, mask=mask, other=0.0) + do = tl.load(do_ptr + offs, mask=mask, other=0.0) + + # compute and write-back to delta + if IS_FP8: + descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hid) + + # NOTE: do is in the fp8 range and o is not in fp8 + delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) + else: + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + + offs_delta = (bid * stride_delta_b + + hid * stride_delta_h + + q_start * stride_delta_m + offs_m * stride_delta_m) + tl.store(delta_ptr + offs_delta, delta, mask=mask_m) + +@triton.jit +def _bwd_dq_inner( + dq, + q, K, V, do, m, Delta, sm_scale, + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropout_m, stride_dropout_n, + stride_deltam, + seqlen_q, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + RCP_LN2: tl.constexpr = 1.4426950408889634 + + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + + # D (= delta) is pre-divided by ds_scale. 
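+ # [illustrative note] Delta holds the per-row term delta_i = sum_d(o[i, d] * do[i, d])
+ # produced by _bwd_preprocess above; it enters the softmax backward as
+ # ds[i, j] = p[i, j] * (dp[i, j] - delta_i), which is exactly how ds is formed further
+ # down in this loop.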
+ Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + + curr_n = start_n + step_n = BLOCK_N + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + for blk_idx in range(num_steps): + offs_n = curr_n + tl.arange(0, BLOCK_N) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + mask_kT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD: + mask_kT &= offs_k[:, None] < BLOCK_D_MODEL + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + + #dropout + if ENABLE_DROPOUT: + philox_offs = (curr_philox_offset + + offs_m[:, None] * stride_dropout_m + + offs_n[None, :] * stride_dropout_n) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + #qk + if IS_FP8: + qk = tl.dot(q, kT) * descale_q * descale_k + else: + qk = tl.dot(q, kT) + p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) + + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask * mask_mn + p = tl.where(mask, p, 0.0) + + #dp + if IS_FP8: + dp = (tl.dot(do, vT) * descale_do * descale_v) + else: + dp = tl.dot(do, vT) + + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + + #ds + delta_i = Di[:, None] + ds = p * (dp - delta_i) + + #dq + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + if IS_FP8: + scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) + dq += (tl.dot((ds*scale_ds).to(kT.type.element_ty), tl.trans(kT)) * descale_ds * descale_k) + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + + curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + + +@triton.jit +def _bwd_dkdv_inner( + dk, dv, + Q, k, v, DO, M, D, sm_scale, + stride_q_m, stride_q_k, + stride_do_m, stride_do_k, + stride_dropout_m, stride_dropout_n, + stride_deltam, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + seqlen_q, seqlen_k, + start_n, start_m, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + qT_ptrs = Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k #[BLOCK_D_MODEL_POW2, BLOCK_M] + do_ptrs = DO + offs_m[:, None] * stride_do_m + offs_k[None,: ] * stride_do_k + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 + + #Iterate over blocks(BLOCK_M size) of Q while calculating + #a fixed block(BLOCK_N) of dk and dv. Note, during backward + #pass P has to be recomputed. However, this kernel computes + #dV and dK, so we compute we need P^T and S^T. 
See backward pass + #equations + # + #From Flash Attention Paper: + #ForwardPass: S = QkT, P=softmax(S), O=PV + # + #BackwardPass equations + #dV = P^TdO + #dP = dOV^T + #dS = dsoftmax(dP) + #dQ = dSK + #dK = QdS^T + for blk_idx in range(num_steps): + offs_m = curr_m + tl.arange(0, BLOCK_M) + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < BLOCK_D_MODEL + mask_do &= offs_k[None, :] < BLOCK_D_MODEL + + #load qT + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + + #dropout + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = (curr_philox_offset + + offs_m[None, :] * stride_dropout_m + + offs_n[:, None] * stride_dropout_n) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + + #Load M + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + + #Compute qkT + if IS_FP8: + qkT = (tl.dot(k, qT) * descale_q * descale_k) + else: + qkT = tl.dot(k, qT) + + #Compute pT(use m and also apply sm_scale) + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + if MASK: + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + pT = tl.where(mask, pT, 0.0) + + #load DO + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + + #dV + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) + dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) * descale_p_dropout * descale_do) + else: + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + if IS_FP8: + scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) + dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + #Load delta + Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + + #Compute dP and dS + if IS_FP8: + dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + else: + dpT = tl.dot(v, tl.trans(do)) + + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + + #compute dk + if IS_FP8: + scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + + #increment pointers + curr_m += step_m + qT_ptrs += step_m * stride_q_m + do_ptrs += step_m * stride_do_m + + return dk, dv + + +@triton.jit +def _bwd_dkdvdq_inner( + dk, dv, + Q, k, v, DO, DQ, M, D, sm_scale, + stride_q_m, stride_q_k, + stride_do_m, stride_do_k, + stride_dropout_m, stride_dropout_n, + stride_deltam, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + seqlen_q, seqlen_k, + start_n, start_m, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + workgroup_id: tl.int32, +): + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = 
tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + + qT_ptrs_start = Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k #[BLOCK_D_MODEL_POW2, BLOCK_M] + dq_ptrs_start = DQ + offs_m[:, None] * stride_q_m + offs_k[None,:] * stride_q_k #[BLOCK_M, BLOCK_D_MODEL_POW2] + + do_ptrs_start = DO + offs_m[:, None] * stride_do_m + offs_k[None,: ] * stride_do_k + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 + + #Iterate over blocks(BLOCK_M size) of Q while calculating + #a fixed block(BLOCK_N) of dk and dv. Note, during backward + #pass P has to be recomputed. However, this kernel computes + #dV and dK, so we compute we need P^T and S^T. See backward pass + #equations + # + #From Flash Attention Paper: + #ForwardPass: S = QkT, P=softmax(S), O=PV + # + #BackwardPass equations + #dV = P^TdO + #dP = dOV^T + #dS = dsoftmax(dP) + #dQ = dSK + #dK = QdS^T + + # Compute a starting index and step based on workgroup_id + # Use a simple hash-like function to spread out the starting points + start_idx = (workgroup_id * 17) % num_steps # 17 is an arbitrary prime to spread indices + # Ensure step is coprime with num_steps to visit all indices exactly once + step = 1 # 3 if num_steps > 1 or num_steps==3 else 1 # coprime with num_steps + + + for iter in range(num_steps): + # Compute the permuted block index + blk_idx = (start_idx + iter * step) % num_steps + + curr_m = start_m + blk_idx * step_m + qT_ptrs = qT_ptrs_start + blk_idx * step_m * stride_q_m + dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_q_m + do_ptrs = do_ptrs_start + blk_idx * step_m * stride_do_m + + offs_m = curr_m + tl.arange(0, BLOCK_M) + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < BLOCK_D_MODEL + mask_do &= offs_k[None, :] < BLOCK_D_MODEL + + #load qT + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + + #dropout + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = (curr_philox_offset + + offs_m[None, :] * stride_dropout_m + + offs_n[:, None] * stride_dropout_n) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + + #Load M + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + + #Compute qkT + if IS_FP8: + qkT = (tl.dot(k, qT) * descale_q * descale_k) + else: + qkT = tl.dot(k, qT) + + #Compute pT(use m and also apply sm_scale) + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + if MASK: + causal_mask = (offs_m[None, :] - delta_qk) >= (offs_n[:, None]) + mask = causal_mask & mask_nm + pT = tl.where(mask, pT, 0.0) + + #load DO + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + + #dV + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) + dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) * descale_p_dropout * descale_do) + else: + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + if IS_FP8: + scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) + dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + #Load delta + Di = 
tl.load(D + offs_m * stride_deltam, mask=mask_m) + + #Compute dP and dS + if IS_FP8: + dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + else: + dpT = tl.dot(v, tl.trans(do)) + + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + + #compute dk + if IS_FP8: + scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + + + # We can compute the dq_partial here and do a atomic add to the correct memory location + # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before + # (BLOCK_M, BLOCK_N) x (BLOCK_N, D) + if IS_FP8: + dq_partial = tl.dot((dsT * scale_dsT).to(k.dtype).T, k) * descale_dsT * descale_k + else: + dq_partial = tl.dot(dsT.to(k.dtype).T, k) + tl.atomic_add( + dq_ptrs, + dq_partial * sm_scale, + mask=mask_m[:, None], + sem="relaxed", + ) + + return dk, dv + + +@triton.jit +def _bwd_kernel_dkdvdq_causal( + q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dk_ptr, dv_ptr, dq_ptr, + m_ptr, delta_ptr, + stride_q_b, stride_q_h, stride_q_m, stride_q_k, + stride_k_b, stride_k_h, stride_k_n, stride_k_k, + stride_v_b, stride_v_h, stride_v_n, stride_v_k, + stride_dk_b, stride_dk_h, stride_dk_n, stride_dk_k, + stride_delta_b, stride_delta_h, stride_delta_m, + stride_do_b, stride_do_h, stride_do_m, stride_do_k, + stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BATCH, + NUM_K_PIDS, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + wid = tl.program_id(0) # workgoup id: 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 + + # workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim + batch_idx = wid % BATCH + head_k_idx = wid // BATCH % NUM_K_HEADS + seq_k_blk_idx = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS + + #Determine q and k start along with seqlen_q and seqlen_k + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + batch_idx) + q_end = tl.load(cu_seqlens_q + batch_idx + 1) + k_start = tl.load(cu_seqlens_k + batch_idx) + k_end = tl.load(cu_seqlens_k + batch_idx + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. 
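+    # Worked example of the skip logic below (illustrative numbers, not from the source):
+    # with seqlen_q=512, seqlen_k=384 and BLOCK_M=BLOCK_N=64, delta_qk=128, so the N-tile
+    # at start_n=0 only reaches the causal diagonal from start_m = start_n + delta_qk = 128
+    # onward; the first two M blocks contribute nothing and are skipped.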
+ delta_qk = seqlen_q - seqlen_k + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N + delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M + if delta_qk >= 0: + start_delta = delta_qk + else: + start_delta = start_delta_q_lt_k + + start_n = seq_k_blk_idx * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_kv &= mask_k[None, :] + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = (batch_idx * stride_k_b + + head_k_idx * stride_k_h + + k_start * stride_k_n + offs_n[:, None] * stride_k_n + + offs_k[None, :] * stride_k_k) + adj_v = (batch_idx * stride_v_b + + head_k_idx * stride_v_h + + k_start * stride_v_n + offs_n[:, None] * stride_v_n + + offs_k[None, :] * stride_v_k) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(k_ptr + adj_k , mask=mask_kv, other=0.0) + v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) + + # If MQA / GQA, set the K and V head offsets appropriately. + for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N + else: + start_m = max(start_n + delta_qk, 0) + start_m = (start_m // BLOCK_M) * BLOCK_M + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N + residue_m + + # offset input and output tensor by batch and Q/K heads + adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m + + q_ptr_adj = q_ptr + adj_q + dq_ptr_adj = dq_ptr + adj_q + + adj_do = batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m + do_ptr_adj = do_ptr + adj_do + adj_delta = batch_idx * stride_delta_b + head_q_idx * stride_delta_h + q_start * stride_delta_m + m_ptr_adj = m_ptr + adj_delta + delta_ptr_adj = delta_ptr + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = (philox_offset_base + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h) + dropout_offset = (dropout_mask + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h) + + MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M) + + + # when q < k, we may skip the initial masked op + # if seq_k_blk_idx < num_blocks_skip: + # num_steps = 0 + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) + descale_k = tl.load(descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx) + descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) + descale_do = tl.load(descale_do_ptr + batch_idx * 
stride_descale_do_z + head_q_idx) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + dk, dv = _bwd_dkdvdq_inner( + dk, dv, # output tensors + q_ptr_adj, k, v, do_ptr_adj, dq_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors + stride_q_m, stride_q_k, # strides for q + stride_do_m, stride_do_k, # strides for o + stride_dropout_m, stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK_BLOCK_M, BLOCK_N, # block dim + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + workgroup_id=seq_k_blk_idx, + ) + start_m += num_steps * MASK_BLOCK_M + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) + end_m = start_m + num_steps * BLOCK_M + + dk, dv = _bwd_dkdvdq_inner( + dk, dv, # output tensors + q_ptr_adj, k, v, do_ptr_adj, dq_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors + stride_q_m, stride_q_k, # strides for q + stride_do_m, stride_do_k, # strides for o + stride_dropout_m, stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + BLOCK_M, BLOCK_N, # block dim + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + workgroup_id=seq_k_blk_idx, + ) + + # Write back dV and dK. 
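+    # Note on scaling (derived from the code above, not an extra step): dv needs no extra
+    # factor because P^T used for dV is already the full softmax probability; dk (and dq,
+    # via the atomic add inside _bwd_dkdvdq_inner) each pick up the single sm_scale factor
+    # that comes from d(QK^T) = sm_scale * dS.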
+ offs_dkdv = (batch_idx * stride_dk_b + + head_k_idx * stride_dk_h + + k_start * stride_dk_n + offs_n[:, None] * stride_dk_n + + offs_k[None, :] * stride_dk_k) + tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_dkdv_causal( + q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dk_ptr, dv_ptr, + m_ptr, delta_ptr, + stride_q_b, stride_q_h, stride_q_m, stride_q_k, + stride_k_b, stride_k_h, stride_k_n, stride_k_k, + stride_v_b, stride_v_h, stride_v_n, stride_v_k, + stride_dk_b, stride_dk_h, stride_dk_n, stride_dk_k, + stride_delta_b, stride_delta_h, stride_delta_m, + stride_do_b, stride_do_h, stride_do_m, stride_do_k, + stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + #seq block, batch, head_k + seq_k_blk_idx = tl.program_id(0) + batch_idx = tl.program_id(1) + head_k_idx = tl.program_id(2) + + #Determine q and k start along with seqlen_q and seqlen_k + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + batch_idx) + q_end = tl.load(cu_seqlens_q + batch_idx + 1) + k_start = tl.load(cu_seqlens_k + batch_idx) + k_end = tl.load(cu_seqlens_k + batch_idx + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. 
+ delta_qk = seqlen_q - seqlen_k + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N + delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M + if delta_qk >= 0: + start_delta = delta_qk + else: + start_delta = start_delta_q_lt_k + + start_n = seq_k_blk_idx *BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_kv &= mask_k[None, :] + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = (batch_idx * stride_k_b + + head_k_idx * stride_k_h + + k_start * stride_k_n + offs_n[:, None] * stride_k_n + + offs_k[None, :] * stride_k_k) + adj_v = (batch_idx * stride_v_b + + head_k_idx * stride_v_h + + k_start * stride_v_n + offs_n[:, None] * stride_v_n + + offs_k[None, :] * stride_v_k) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(k_ptr + adj_k , mask=mask_kv, other=0.0) + v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) + + # If MQA / GQA, set the K and V head offsets appropriately. + for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M * BLOCK_M + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N + residue_m + + # offset input and output tensor by batch and Q/K heads + adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m + q_ptr_adj = q_ptr + adj_q + adj_do = batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m + do_ptr_adj = do_ptr + adj_do + adj_delta = batch_idx * stride_delta_b + head_q_idx * stride_delta_h + q_start * stride_delta_m + m_ptr_adj = m_ptr + adj_delta + delta_ptr_adj = delta_ptr + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = (philox_offset_base + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h) + dropout_offset = (dropout_mask + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h) + + MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M) + # when q < k, we may skip the initial masked op + if seq_k_blk_idx < num_blocks_skip: + num_steps = 0 + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) + descale_k = tl.load(descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx) + descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) + descale_do = tl.load(descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx) + else: + 
descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + q_ptr_adj, k, v, do_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors + stride_q_m, stride_q_k, # strides for q + stride_do_m, stride_do_k, # strides for o + stride_dropout_m, stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK_BLOCK_M, BLOCK_N, # block dim + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + start_m += num_steps * MASK_BLOCK_M + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) + end_m = start_m + num_steps * BLOCK_M + + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + q_ptr_adj, k, v, do_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors + stride_q_m, stride_q_k, # strides for q + stride_do_m, stride_do_k, # strides for o + stride_dropout_m, stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + BLOCK_M, BLOCK_N, # block dim + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + + # Write back dV and dK. 
+ offs_dkdv = (batch_idx * stride_dk_b + + head_k_idx * stride_dk_h + + k_start * stride_dk_n + offs_n[:, None] * stride_dk_n + + offs_k[None, :] * stride_dk_k) + tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv) + +@triton.jit +def _bwd_kernel_dq_causal( + q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dq_ptr, + m_ptr, delta_ptr, + stride_q_b, stride_q_h, stride_q_m, stride_q_k, + stride_k_b, stride_k_h, stride_k_n, stride_k_k, + stride_v_b, stride_v_h, stride_v_n, stride_v_k, + stride_dq_b, stride_dq_h, stride_dq_m, stride_dq_k, + stride_delta_b, stride_delta_h, stride_delta_m, + stride_do_b, stride_do_h, stride_do_m, stride_do_k, + stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + seq_q_blk_idx = tl.program_id(0) + batch_idx = tl.program_id(1) + head_k_idx = tl.program_id(2) + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + batch_idx) + q_end = tl.load(cu_seqlens_q + batch_idx + 1) + k_start = tl.load(cu_seqlens_k + batch_idx) + k_end = tl.load(cu_seqlens_k + batch_idx + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + # Figure out causal starting block since we have seqlen_q <=> seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. + # DQ tiles on M dim and iterate on N dim, so we there could be some tiles we + # can simply skip and we need to adjust starting position. + start_m = seq_q_blk_idx * BLOCK_M + # seqlen_q > seqlen_k, no need to process these tile for dq + delta_qk = seqlen_q - seqlen_k + if start_m + BLOCK_M < delta_qk: + return + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_m = start_m + tl.arange(0, BLOCK_M) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_q_m + offs_k[None, :] * stride_q_k + offs_do = offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k + adj_k = batch_idx * stride_k_b + head_k_idx * stride_k_h + k_start * stride_k_n + adj_v = batch_idx * stride_v_b + head_k_idx * stride_v_h + k_start * stride_v_n + k_ptr_adj = k_ptr + v_ptr_adj = v_ptr + k_ptr_adj += adj_k + v_ptr_adj += adj_v + + # If MQA / GQA, set the K and V head offsets appropriately. 
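+    # Illustrative GQA example (hypothetical numbers): with NUM_Q_HEADS=8 and NUM_K_HEADS=2,
+    # GROUP_SIZE=4, so the program handling head_k_idx=1 loops over query heads 4..7,
+    # all of which share K/V head 1.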
+ GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + + # offset input and output tensor by batch and Q/K heads + adj_q = (batch_idx * stride_q_b + + head_q_idx * stride_q_h + + q_start * stride_q_m) + adj_do = (batch_idx * stride_do_b + + head_q_idx * stride_do_h + + q_start * stride_do_m) + adj_delta = (batch_idx * stride_delta_b + + head_q_idx * stride_delta_h + + q_start * stride_delta_m) + delta_ptr_adj = delta_ptr + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = (philox_offset_base + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h) + dropout_offset = (dropout_mask + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h) + + q = tl.load(q_ptr + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(do_ptr + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(m_ptr + adj_delta + offs_m * stride_delta_m, + mask=offs_m < seqlen_q) + m = m[:, None] + + MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N) + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) + descale_k = tl.load(descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx) + descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) + descale_do = tl.load(descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _bwd_dq_inner, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + dq = _bwd_dq_inner( + dq, + q, k_ptr_adj, v_ptr_adj, do, m, delta_ptr_adj, sm_scale, + stride_q_m, stride_q_k, stride_k_n, stride_k_k, stride_v_n, stride_v_k, + stride_dropout_m, stride_dropout_n, + stride_delta_m, + seqlen_q, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M, MASK_BLOCK_N, + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, + MASK=True, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + end_n -= num_steps * MASK_BLOCK_N + num_steps = tl.cdiv(end_n, BLOCK_N) + start_n = max(end_n - num_steps * BLOCK_N, 0) + dq = _bwd_dq_inner( + dq, + q, k_ptr_adj, v_ptr_adj, do, m, delta_ptr_adj, sm_scale, + stride_q_m, stride_q_k, stride_k_n, stride_k_k, stride_v_n, stride_v_k, + stride_dropout_m, stride_dropout_n, + stride_delta_m, + seqlen_q, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M, BLOCK_N, + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + # Write back dQ. 
+ offs_dq = (batch_idx * stride_dq_b + + head_q_idx * stride_dq_h + + q_start * stride_dq_m + + offs_m[:, None] * stride_dq_m + + offs_k[None, :] * stride_dq_k) + dq *= sm_scale + tl.store(dq_ptr + offs_dq, dq, mask=mask_q) + + +@triton.jit +def _bwd_kernel_dkdvdq_noncausal( + Q, K, V, sm_scale, DO, DK, DV, DQ, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BATCH, + NUM_K_PIDS, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + # workgroup id + wid = tl.program_id(0) # 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 + + # Workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim + # This is in order to avoid contention for the tl.atomic_add (inside _bwd_dkdvdq_inner) that happens between workgroups that share the same batch and head_k. + bid = wid % BATCH + hkid = wid // BATCH % NUM_K_HEADS + pid = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + mask_kv &= offs_k < BLOCK_D_MODEL + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = (bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_k[None, :] * stride_kk) + adj_v = (bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_k[None, :] * stride_vk) + + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + adj_q = (bid * stride_qb + hqid * stride_qh + q_start * stride_qm) + + Q_ptr = Q + adj_q + DQ_ptr = DQ + adj_q + + adj_do = (bid * stride_dob + hqid * stride_doh + q_start * stride_dom) + DO_ptr = DO + adj_do + adj_delta = (bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam) + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + #dropout + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + 
hqid * stride_dropouth + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) + descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) + descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M) + + dk, dv = _bwd_dkdvdq_inner( + dk, dv, + Q_ptr, k, v, DO_ptr, DQ_ptr, M_ptr, Delta_ptr, sm_scale, + stride_qm, stride_qk, + stride_dom, stride_dok, + stride_dropoutm, stride_dropoutn, + stride_deltam, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + seqlen_q, seqlen_k, + start_n, start_m, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M, BLOCK_N, + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + workgroup_id=pid, + ) + + adj_dkdv = (bid * stride_dkb + + hkid * stride_dkh + + k_start * stride_dkn + offs_n[:, None] * stride_dkn + + offs_k[None, :] * stride_dkk) + tl.store(DV + adj_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv, dk, mask=mask_kv) + + + +@triton.jit +def _bwd_kernel_dkdv_noncausal( + Q, K, V, sm_scale, DO, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + mask_kv &= offs_k < BLOCK_D_MODEL + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = (bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_k[None, :] * stride_kk) + adj_v = (bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_k[None, :] * stride_vk) + + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + + for hqid in range(hkid * GROUP_SIZE, hkid * 
GROUP_SIZE + GROUP_SIZE): + adj_q = (bid * stride_qb + hqid * stride_qh + q_start * stride_qm) + Q_ptr = Q + adj_q + adj_do = (bid * stride_dob + hqid * stride_doh + q_start * stride_dom) + DO_ptr = DO + adj_do + adj_delta = (bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam) + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + #dropout + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) + descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) + descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M) + dk, dv = _bwd_dkdv_inner( + dk, dv, + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, + stride_qm, stride_qk, + stride_dom, stride_dok, + stride_dropoutm, stride_dropoutn, + stride_deltam, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + seqlen_q, seqlen_k, + start_n, start_m, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M, BLOCK_N, + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + + adj_dkdv = (bid * stride_dkb + + hkid * stride_dkh + + k_start * stride_dkn + offs_n[:, None] * stride_dkn + + offs_k[None, :] * stride_dkk) + tl.store(DV + adj_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_dq_noncausal( + Q, K, V, sm_scale, DO, DQ, + M, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + pid = tl.program_id(0) #seqlen + bid = tl.program_id(1) #batch + hkid = tl.program_id(2) #head_k + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + start_m = pid * BLOCK_M + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_m = start_m + tl.arange(0, BLOCK_M) + + #mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + 
mask_k = offs_k < BLOCK_D_MODEL + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + K += adj_k + V += adj_v + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + delta_ptr = delta + adj_delta + + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = (philox_offset_base + + bid * stride_dropoutb + + hqid * stride_dropouth) + dropout_offset = ( + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth) + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) + m = m[:, None] + + #FP8 + if IS_FP8: + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) + descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) + descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N) + dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, delta_ptr, sm_scale, + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, + stride_deltam, + seqlen_q, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M, BLOCK_N, + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + +def _flash_attn_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + philox_seed: Optional[int] = 0, + philox_offset: Optional[int] = 0, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None, + fused: bool = False, +): + IS_FP8 = is_fp8(q) + if IS_FP8: + FP8_MAX = torch.finfo(q.dtype).max + descale_strides = (descale_q.stride(0),descale_k.stride(0),descale_v.stride(0),descale_do.stride(0) ) + else: + FP8_MAX = None + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_do_z = None + descale_strides = (stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z) + + IS_VARLEN 
= cu_seqlens_q is not None
+
+    #get strides and shape
+    if IS_VARLEN:
+        #Layout for q,k,v is thd, i.e. [total_tokens, num_head, head_dim]
+        batch, seqlen_q, num_q_heads, head_sz = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2]
+        seqlen_k, num_k_heads = max_seqlen_k, k.shape[1]
+        q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
+        k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
+        v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
+        o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
+        dq_strides = (0, dq.stride(1), dq.stride(0), dq.stride(2))
+        dk_strides = (0, dk.stride(1), dk.stride(0), dk.stride(2))
+        dv_strides = (0, dv.stride(1), dv.stride(0), dv.stride(2))
+        do_strides = (0, do.stride(1), do.stride(0), do.stride(2))
+    else:
+        #Layout for q,k,v is bshd, i.e. [batch, seq_len, num_head, head_dim]
+        batch, seqlen_q, num_q_heads, head_sz = q.shape
+        seqlen_k, num_k_heads = k.shape[1], k.shape[2]
+        q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
+        k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
+        v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
+        o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
+        dq_strides = (dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3))
+        dk_strides = (dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3))
+        dv_strides = (dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3))
+        do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3))
+
+    #BLOCK_D_MODEL, BLOCK_D_MODEL_POW2
+    #padding for head_dim: next power of 2, with a minimum of 16
+    BLOCK_D_MODEL_POW2 = triton.next_power_of_2(head_sz)
+    BLOCK_D_MODEL_POW2 = max(BLOCK_D_MODEL_POW2, 16)
+
+    #Configs
+    #PRE_BLOCK, BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2
+    #BLK_SLICE_FACTOR
+    NUM_WARPS, NUM_STAGES = 4, 1
+    WAVES_PER_EU = 1
+    PRE_BLOCK = 128
+    #BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
+    BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16
+    BLK_SLICE_FACTOR = 2
+
+    #init delta
+    delta = torch.zeros_like(softmax_lse)
+    if IS_VARLEN:
+        #[total_tokens, num_q_heads, seqlen_q]
+        delta_strides = (0, delta.stride(1), delta.stride(0))
+    else:
+        #[batch, num_q_heads, seqlen_q]
+        delta_strides = delta.stride()
+
+    #preprocess
+    #compute D (= delta) = rowsum(dO * O). Note, the multiplication is element-wise.
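+    #(reference sketch only, ignoring layout/varlen handling and FP8 descaling: the same
+    # quantity in eager PyTorch is roughly (o.float() * do.float()).sum(-1))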
+ pre_grid = (triton.cdiv(max_seqlen_q, PRE_BLOCK), batch, num_q_heads) + _bwd_preprocess[pre_grid]( + o, do, + delta, + *o_strides, + *delta_strides, + descale_strides[3], + cu_seqlens_q, max_seqlen_q, + descale_do, + BLOCK_M=PRE_BLOCK, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8 + ) + + #dropout_mask + use_dropout = (dropout_p > 0.0) + if use_dropout: + dropout_mask = torch.zeros( + (batch, num_q_heads, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32) + dropout_strides = dropout_mask.stride() + else: + dropout_mask = None + dropout_strides = (0, 0, 0, 0) + + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, num_k_heads) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, num_k_heads) + + if fused: # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups + + BLOCK_N = 128 + config = { + "BLOCK_M": 32, + "BLOCK_N": BLOCK_N, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 1, + "BLK_SLICE_FACTOR": 2, + } + + num_k_pids = (max_seqlen_k + BLOCK_N - 1) // BLOCK_N + grid_dkdvdq = (batch * num_k_heads * num_k_pids,) + + if causal: + _bwd_kernel_dkdvdq_causal[grid_dkdvdq]( + q, k, v, sm_scale, do, dk, dv, dq, + softmax_lse, delta, + *q_strides, + *k_strides, + *v_strides, + *dk_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask,dropout_p, philox_seed, philox_offset, + descale_q, descale_k, descale_v, descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BATCH=batch, + NUM_K_PIDS=num_k_pids, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + **config, + ) + else: + _bwd_kernel_dkdvdq_noncausal[grid_dkdvdq]( + q, k, v, sm_scale, do, dk, dv, dq, + softmax_lse, delta, + *q_strides, + *k_strides, + *v_strides, + *dk_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask,dropout_p, philox_seed, philox_offset, + descale_q, descale_k, descale_v, descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BATCH=batch, + NUM_K_PIDS=num_k_pids, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + **config, + ) + + return delta + + # split kernels solution: one kernel computes dk, dv and the other computes dq + + if causal: + _bwd_kernel_dkdv_causal[grid_dkdv]( + q, k, v, sm_scale, do, dk, dv, + softmax_lse, delta, + *q_strides, + *k_strides, + *v_strides, + *dk_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask,dropout_p, philox_seed, philox_offset, + descale_q, descale_k, descale_v, descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + _bwd_kernel_dq_causal[grid_dq]( + q, k, v, sm_scale, do, dq, + softmax_lse, delta, + *q_strides, + *k_strides, + *v_strides, + *dq_strides, + 
*delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask,dropout_p, philox_seed, philox_offset, + descale_q, descale_k, descale_v, descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + else: + _bwd_kernel_dkdv_noncausal[grid_dkdv]( + q, k, v, sm_scale, do, dk, dv, + softmax_lse, delta, + *q_strides, + *k_strides, + *v_strides, + *dk_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask,dropout_p, philox_seed, philox_offset, + descale_q, descale_k, descale_v, descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + + _bwd_kernel_dq_noncausal[grid_dq]( + q, k, v, sm_scale, do, dq, + softmax_lse, delta, + *q_strides, + *k_strides, + *v_strides, + *dq_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask,dropout_p, philox_seed, philox_offset, + descale_q, descale_k, descale_v, descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + + return delta + + +class FlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_softmax, + is_grad_enabled, + fused_backward, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q,k,v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + + head_size_og = q.size(3) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + alibi_slopes=alibi_slopes, + return_lse=return_lse, + return_softmax=return_softmax and dropout_p > 0, + max_seqlen_q=q.shape[1], + max_seqlen_k=k.shape[1], + ) + + if is_grad: + ctx.save_for_backward(q, k, v, out_padded, softmax_lse) + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.fused_backward = fused_backward + + + out = 
out_padded[..., :head_size_og] + result = [out] + if return_lse: + result.append(softmax_lse) + if return_softmax: + result.append(S_dmask) + + return tuple(result) + + @staticmethod + def backward(ctx, do, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v) + head_size_v_og = do.size(3) + do_padded = do + if head_size_v_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) + _flash_attn_backward( + do_padded, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.alibi_slopes, + ctx.causal, + None, + None, + max_seqlen_q=q.shape[1], + max_seqlen_k=k.shape[1], + dropout_p=ctx.dropout_p, + philox_seed=ctx.philox_seed, + philox_offset=ctx.philox_offset, + fused=ctx.fused_backward, + ) + dq = dq[..., : q.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + +def flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1,-1), + alibi_slopes=None, + deterministic=True, + return_lse=False, + return_attn_probs=False, + fused_backward=False, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_lse=True]: (batch_size, nheads, seqlen). 
The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_attn_probs, + torch.is_grad_enabled(), + fused_backward, + ) + + +class FlashAttnFP8Func(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_softmax, + is_grad_enabled, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q,k,v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(3) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + # cast input to fp8 + fp8_dtype = torch.float8_e4m3fnuz + q_fp8, descale_q = cast_to_fp8(q, fp8_dtype, "bshd") + k_fp8, descale_k = cast_to_fp8(k, fp8_dtype, "bshd") + v_fp8, descale_v = cast_to_fp8(v, fp8_dtype, "bshd") + + out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + alibi_slopes=alibi_slopes, + return_lse=return_lse, + return_softmax=return_softmax and dropout_p > 0, + max_seqlen_q=q.shape[1], + max_seqlen_k=k.shape[1], + cu_seqlens_q=None, + cu_seqlens_k=None, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v + ) + + if is_grad: + ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_padded, softmax_lse, descale_q, descale_k, descale_v) + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + + out = out_padded[..., :head_size_og] + result = [out] + if return_lse: + result.append(softmax_lse) + if return_softmax: + result.append(S_dmask) + + return tuple(result) + + @staticmethod + def backward(ctx, do, *args): + q_fp8, k_fp8, v_fp8, out, softmax_lse, descale_q, descale_k, descale_v = ctx.saved_tensors + dq, dk, dv = torch.zeros_like(q_fp8, dtype=torch.float32), torch.zeros_like(k_fp8, dtype=torch.float32), torch.zeros_like(v_fp8, dtype=torch.float32) + head_size_v_og = do.size(3) + do_padded = do + if head_size_v_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) + + fp8_dtype = torch.float8_e4m3fnuz + do_padded_fp8, descale_do = cast_to_fp8(do_padded, fp8_dtype, "bshd") + _flash_attn_backward( + do_padded_fp8, + q_fp8, + k_fp8, + v_fp8, + out, + softmax_lse, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.alibi_slopes, + ctx.causal, + None, + None, + max_seqlen_q=q_fp8.shape[1], + max_seqlen_k=k_fp8.shape[1], + dropout_p=ctx.dropout_p, + philox_seed=ctx.philox_seed, + philox_offset=ctx.philox_offset, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_do=descale_do, + ) + #dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension + #dk = dk[..., : k_fp8.shape[-1]] + 
#dv = dv[..., : v_fp8.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None + +def flash_attn_fp8_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_lse=False, + return_attn_probs=False +): + return FlashAttnFP8Func.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_attn_probs, + torch.is_grad_enabled() + ) + +class FlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_softmax, + block_table, + is_grad_enabled, + fused_backward, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + alibi_slopes=alibi_slopes, + return_lse=return_lse, + return_softmax=return_softmax and dropout_p > 0.0, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if is_grad: + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.fused_backward = fused_backward + out = out_padded[..., :head_size_og] + + result = [out] + if return_lse: + result.append(softmax_lse) + if return_softmax: + result.append(S_dmask) + + return tuple(result) + + @staticmethod + def backward(ctx, do, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v) + head_size_og = do.size(2) + do_padded = do + if head_size_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) + _flash_attn_backward( + do_padded, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.alibi_slopes, + ctx.causal, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_k=ctx.max_seqlen_k, + dropout_p=ctx.dropout_p, + philox_seed=ctx.philox_seed, + philox_offset=ctx.philox_offset, + fused=ctx.fused_backward, + ) + dq = dq[..., : q.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1,-1), + alibi_slopes=None, + deterministic=False, + return_lse=False, + 
return_attn_probs=False, + block_table=None, + fused_backward=False, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head + 0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Defaults to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total_q, nheads, headdim). + softmax_lse [optional, if return_lse=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept).
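+
+    Example (a minimal usage sketch; the sequence lengths, head counts, head dim, dtype
+    and device below are illustrative assumptions, not requirements of this API):
+        # two sequences packed along the token axis: query lengths 3 and 5, key lengths 4 and 6
+        cu_q = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
+        cu_k = torch.tensor([0, 4, 10], dtype=torch.int32, device="cuda")
+        q = torch.randn(8, 6, 64, dtype=torch.float16, device="cuda", requires_grad=True)
+        k = torch.randn(10, 2, 64, dtype=torch.float16, device="cuda", requires_grad=True)  # GQA: 6 q-heads share 2 kv-heads
+        v = torch.randn(10, 2, 64, dtype=torch.float16, device="cuda", requires_grad=True)
+        (out,) = flash_attn_varlen_func(q, k, v, cu_q, cu_k, max_seqlen_q=5, max_seqlen_k=6, causal=True)
+        out.sum().backward()  # the wrapper returns a tuple even when only `out` is requested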
+ """ + return FlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_attn_probs, + block_table, + torch.is_grad_enabled(), + fused_backward, + ) + + +class FlashAttnVarlenFP8Func(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_softmax, + block_table, + is_grad_enabled, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + # cast input to fp8 + fp8_dtype = torch.float8_e4m3fnuz + q_fp8, descale_q = cast_varlen_to_fp8(q, fp8_dtype, cu_seqlens=cu_seqlens_q) + k_fp8, descale_k = cast_varlen_to_fp8(k, fp8_dtype, cu_seqlens=cu_seqlens_k) + v_fp8, descale_v = cast_varlen_to_fp8(v, fp8_dtype, cu_seqlens=cu_seqlens_k) + + out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( + q_fp8, + k_fp8, + v_fp8, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + alibi_slopes=alibi_slopes, + return_lse=return_lse, + return_softmax=return_softmax and dropout_p > 0, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v + ) + if is_grad: + ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, descale_q, descale_k, descale_v) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + out = out_padded[..., :head_size_og] + result = [out] + if return_lse: + result.append(softmax_lse) + if return_softmax: + result.append(S_dmask) + + return tuple(result) + + @staticmethod + def backward(ctx, do, *args): + q_fp8, k_fp8, v_fp8, out, softmax_lse, cu_seqlens_q, cu_seqlens_q, descale_q, descale_k, descale_v = ctx.saved_tensors + dq, dk, dv = torch.zeros_like(q, dtype=torch.float32), torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32) + head_size_v_og = do.size(3) + do_padded = do + if head_size_v_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) + + fp8_dtype = torch.float8_e4m3fnuz + do_padded_fp8, descale_do = cast_varlen_to_fp8(dout_padded, fp8_dtype, "thd", cu_seqlens_q) + + _flash_attn_backward( + do_padded_fp8, + q_fp8, + k_fp8, + v_fp8, + out, + softmax_lse, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.alibi_slopes, + ctx.causal, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=ctx.dropout_p, + philox_seed=ctx.philox_seed, + philox_offset=ctx.philox_offset, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_do=descale_do + ) + dq = dq[..., : q.shape[-1]] # We could have padded the head 
dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None + +def flash_attn_varlen_fp8_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_lse=False, + return_attn_probs=False, + block_table=None +): + return FlashAttnVarlenFP8Func.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_attn_probs, + block_table, + torch.is_grad_enabled() + ) \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py new file mode 100644 index 00000000000..3f650d288db --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py @@ -0,0 +1,1091 @@ +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from typing import Literal, Optional +from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shapes_from_layout, compute_fp8_scaling_factors, \ + get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_rdna + +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) +tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) + + +def get_autotune_configs(): + if False: + if is_cdna(): + # shared meta-parameters + NUM_STAGES = 1 + NUM_WARPS = 4 + WAVES_PER_EU = 2 + MATRIX_INSTR_NONKDIM = 16 + + preprocess_autotune_configs = [ + triton.Config({"PRE_BLOCK": 128, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config + triton.Config({"PRE_BLOCK": 64, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + triton.Config({"PRE_BLOCK": 32, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + triton.Config({"PRE_BLOCK": 16, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + ] + preprocess_autotune_keys = [ + "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + causal_autotune_configs = [ + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config + triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, 
num_stages=NUM_STAGES, num_warps=NUM_WARPS), + ] + causal_autotune_keys = [ + "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + noncausal_autotune_configs = [ + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config + triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + ] + noncausal_autotune_keys = [ + "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + + return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) + else: + raise ValueError("Unknown Device Type") + else: + # meta-parameters + # TODO: fix num_stages later + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + PRE_BLOCK = 128 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + + assert BLOCK_N1 == BLOCK_M2 + + # configs for the kernels + preprocess_autotune_configs = [ + triton.Config({"PRE_BLOCK": PRE_BLOCK, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + ] + preprocess_autotune_keys = [ + "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + causal_autotune_configs = [ + triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + ] + causal_autotune_keys = [ + "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + noncausal_autotune_configs = [ + triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + ] + noncausal_autotune_keys = [ + "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) + + + +(preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) = get_autotune_configs() + + +# This function computes delta given output Out and gradient DO +# Here is the I/O shape: +# Out: (batch, nhead_q, max_seqlens_q, headDim) +# DO: (batch, nhead_q, max_seqlens_q, headDim) +# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at +# fwd_prefill.py line 607 +@triton.autotune( + 
configs=preprocess_autotune_configs, + key=preprocess_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def _bwd_preprocess( + O, DO, # noqa: E741 + Delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + stride_descale_do_z, + cu_seqlens_q, max_seqlen_q, + Descale_do, + PRE_BLOCK: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr +): + pid_m = tl.program_id(0) + bid = tl.program_id(1) + hid = tl.program_id(2) + # Handle varlen + q_start = 0 + seqlen_q = max_seqlen_q + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + seqlen_q = q_end - q_start + else: + q_start = 0 + seqlen_q = max_seqlen_q + + # Compute offsets + offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) + offs_k = tl.arange(0, HEAD_DIM) + # Offset O/DO by batch, head and q_start + O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741 + DO += bid * stride_ob + hid * stride_oh + q_start * stride_om + # create masks + mask_m = offs_m < seqlen_q + mask_md = mask_m[:, None] + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_md &= offs_k[None, :] < ACTUAL_HEAD_DIM + # compute pointers + offs_do = offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok + out_ptrs = O + offs_do + do_ptrs = DO + offs_do + # load + o = tl.load(out_ptrs, mask=mask_md, other=0.0) + do = tl.load(do_ptrs, mask=mask_md, other=0.0) + # compute and write-back to delta + if IS_FP8: + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hid) + + # NOTE: do is in the fp8 range and o is not in fp8 + delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) + else: + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + delta_offset = Delta + bid * stride_deltab + hid * stride_deltah + q_start * stride_deltam + tl.store(delta_offset + offs_m * stride_deltam, delta, mask=mask_m) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _bwd_dkdv_inner( + dk, dv, # output + Q, k, v, DO, M, D, sm_scale, # input tensor + stride_qm, stride_qk, + stride_dom, stride_dok, + stride_dropoutm, stride_dropoutn, # + stride_deltam, + BLOCK_M: tl.constexpr, # 16 + BLOCK_N: tl.constexpr, # 128 + HEAD_DIM: tl.constexpr, # + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + # Filled in by the wrapper. 
+ start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal + ENABLE_DROPOUT: tl.constexpr, # activate dropout + USE_EXP2: tl.constexpr, # activate exp2 + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) + offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) + offs_k = tl.arange(0, HEAD_DIM) + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + # Q and DO are (seqlen_q, head_dim) + # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q + qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk + # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N % BLOCK_M == 0) + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + + for blk_idx in range(num_steps): + if DEBUG_TRITON: print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 + offs_m = curr_m + tl.arange(0, BLOCK_M) + # update the mask because offs_m advanced + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM + mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + # generate dropout mask + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = curr_philox_offset + \ + offs_m[None, :] * stride_dropoutm + \ + offs_n[:, None] * stride_dropoutn + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = offs_m[None, :] * stride_dropoutm + \ + offs_n[:, None] * stride_dropoutn + dropout_mask = tl.load( + curr_dropout_offset + dropout_offs, + mask=mask_nm + ) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + # Load m before computing qk to reduce pipeline stall. + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + if IS_FP8: + qkT = (tl.dot(k, qT) * descale_q * descale_k) + else: + qkT = tl.dot(k, qT) + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"qT: {qT.shape}\n", qT) + print(f"k: {k.shape}\n", k) + print(f"qkT scaled: {qkT.shape}\n", qkT * sm_scale) + # TODO: remove the scaling of m later when we removed re-scaling in fwd + if USE_EXP2: + pT = tl.math.exp2(qkT * sm_scale * RCP_LN2 - m[None, :] * RCP_LN2) + else: + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + # Autoregressive masking. 
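+        # The causal mask is aligned to the bottom-right of the (seqlen_q, seqlen_k) score
+        # matrix: with delta_qk = seqlen_q - seqlen_k, query row m may attend key column n
+        # only when (m - delta_qk) >= n, which is exactly the predicate built below.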
+ if MASK: + # offset offs_m with delta_qk since the causal mask starts at + # bottom right of the (seqlen_q, seqlen_k) matrix + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"causal_mask: {causal_mask.shape}\n", causal_mask) + print(f"qkT after causal: {qkT.shape}\n", tl.where(causal_mask, qkT * sm_scale, 0.0)) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + # Compute dV. + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) + dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do)* descale_p_dropout * descale_do) + else: + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + if IS_FP8: + scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) + dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"pT: {pT.shape}\n", pT) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + # Compute dP and dS. + if IS_FP8: + dpT = (tl.dot(v, tl.trans(do)) * descale_v * descale_do) + else: + dpT = tl.dot(v, tl.trans(do)) + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + if IS_FP8: + scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_qm + do_ptrs += step_m * stride_dom + return dk, dv + +# the main inner-loop logic for computing dQ +@triton.jit +def _bwd_dq_inner( + dq, # output + q, K, V, do, m, Delta, sm_scale, # input + # shared by Q/K/V. + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, # stride for dropout + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + # Filled in by the wrapper. + start_m, start_n, end_n, num_steps, # + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
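+    # The masked (diagonal) pass covers a band of BLOCK_M2 key columns in steps of
+    # MASK_BLOCK_N2 = BLOCK_N2 // BLK_SLICE_FACTOR; keeping BLOCK_M2 a multiple of BLOCK_N2
+    # lets that band split into whole steps, which the assert below enforces.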
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + for blk_idx in range(num_steps): + if DEBUG_TRITON: print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 + offs_n = curr_n + tl.arange(0, BLOCK_N2) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 + if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 + mask_kT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD: + mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = curr_philox_offset + \ + offs_m[:, None] * stride_dropoutm + \ + offs_n[None, :] * stride_dropoutn + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = offs_m[:, None] * stride_dropoutm + \ + offs_n[None, :] * stride_dropoutn + dropout_mask = tl.load( + curr_dropout_offset + dropout_offs, + mask=mask_mn) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + if IS_FP8: + qk = (tl.dot(q, kT) * descale_q * descale_k) + else: + qk = tl.dot(q, kT) + if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk * sm_scale) # noqa: E701 + if USE_EXP2: + p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) + else: + p = tl.math.exp(qk * sm_scale - m) + + # Autoregressive masking. + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask & mask_mn + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + if IS_FP8: + dp = (tl.dot(do, vT) * descale_do * descale_v) + else: + dp = tl.dot(do, vT) + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + delta_i = Di[:, None] + ds = p * (dp -delta_i) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + if IS_FP8: + scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) + dq += (tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) * descale_ds * descale_k) + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + # Increment pointers. 
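+        # Advance to the next BLOCK_N2 tile of K/V; dq keeps accumulating across
+        # iterations and is scaled by sm_scale only once, in the calling kernel.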
+ curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + +@triton.autotune( + configs=causal_autotune_configs, + key=causal_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) + Q, K, V, sm_scale, DO, DQ, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + delta_qk = seqlen_q - seqlen_k + if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701 + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + offs_k = tl.arange(0, HEAD_DIM) + GROUP_SIZE: tl.constexpr = HQ // HK + + # align the delta_qk + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + # This section does dk and dv + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N1 + delta_aligned = (num_blocks_skip + 1) * BLOCK_N1 + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M1 * BLOCK_M1 + if delta_qk >= 0: + start_delta = delta_qk + if DEBUG_TRITON: print(f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}") # noqa: E701 + else: + start_delta = start_delta_q_lt_k + if DEBUG_TRITON: print(f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}") # noqa: E701 + + offs_n = start_n + tl.arange(0, BLOCK_N1) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_kv &= mask_k[None, :] + offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + + # K/V tensors not changed for the group + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + # load K and V: they stay in SRAM throughout the inner loop. 
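+        # Each program instance owns one BLOCK_N1 slice of K/V and accumulates the full
+        # dK/dV for that slice by sweeping over all query tiles, so these tiles are
+        # loaded once per (batch, kv-head) and reused for every query head in the group.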
+ k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) + v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. + # hqid = hkid + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N1 + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M1 * BLOCK_M1 + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N1 + residue_m + if DEBUG_TRITON: print(f"residue_m = {residue_m}") # noqa: E701 + + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + \ + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M1) + # when q < k, we may skip the initial masked op + if pid < num_blocks_skip: + num_steps = 0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + MASK_BLOCK_M1, BLOCK_N1, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + None, None, None, None, + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + IS_FP8=False, + FP8_MAX=None, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + start_m += num_steps * MASK_BLOCK_M1 + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M1) + end_m = start_m + num_steps * BLOCK_M1 + + if DEBUG_TRITON: print(f"start_m after Masked step: {start_m}; num_steps: {num_steps}") # noqa: E701 + if DEBUG_TRITON: print(f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}") # noqa: E701 + if DEBUG_TRITON: print("unMasked") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M1, 
BLOCK_N1, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + None, None, None, None, + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + IS_FP8=False, + FP8_MAX=None, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # end of GQA/MQA of dkdv + # Write back dV and dK. + adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn + offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk + tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + + # This part does dq + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + # seqlen_q > seqlen_k, no need to process these tile for dq + if DEBUG_TRITON: print(f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2}") # noqa: E701 + if start_m + BLOCK_M2 < delta_qk: + if DEBUG_TRITON: print(f"start_m + BLOCK_M2 = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2} < delta_qk of {delta_qk}") # noqa: E701 + return + + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + K += adj_kv + V += adj_kv + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M2 - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + if DEBUG_TRITON: print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M2, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2, MASK_BLOCK_N2, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, 
batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + None, None, None, None, + MASK=True, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + IS_FP8=False, + FP8_MAX=None, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + end_n -= num_steps * MASK_BLOCK_N2 + num_steps = tl.cdiv(end_n, BLOCK_N2) + start_n = max(end_n - num_steps * BLOCK_N2, 0) + if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 + dq = _bwd_dq_inner( + dq, # + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, # + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2, BLOCK_N2, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + None, None, None, None, + MASK=False, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + IS_FP8=False, + FP8_MAX=None, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + # end of GQA/MQA of dq + +@triton.autotune( + configs=noncausal_autotune_configs, + key=noncausal_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def bwd_kernel_noncausal( + Q, K, V, sm_scale, DO, DQ, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M1: tl.constexpr, # 32 + BLOCK_N1: tl.constexpr, # 128 + BLOCK_M2: tl.constexpr, # 128 + BLOCK_N2: tl.constexpr, # 32 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + offs_k = tl.arange(0, HEAD_DIM) + GROUP_SIZE: tl.constexpr = HQ // HK + + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + offs_n = start_n + tl.arange(0, BLOCK_N1) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_kv &= mask_k[None, :] + 
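+        # offs_kv addresses a (BLOCK_N1, HEAD_DIM) tile; mask_kv already zeroes rows past
+        # seqlen_k and, when the head dim was padded up to a power of two, columns past
+        # ACTUAL_HEAD_DIM, so out-of-bounds lanes load as 0.0 and are never stored.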
offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + + # K/V tensors not changed for the group + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) + v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + # because there is no causal, we always start from the beginning + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M1) + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M1, BLOCK_N1, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + None, None, None, None, + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + IS_FP8=False, + FP8_MAX=None, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV and dK. + adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn + offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk + tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + + # THIS PART DOES DQ + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + K += adj_kv + V += adj_kv + # If MQA / GQA, set the K and V head offsets appropriately. 
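+        # GQA/MQA mapping: each kv-head hkid serves GROUP_SIZE = HQ // HK query heads, so
+        # dQ is computed for hqid in [hkid * GROUP_SIZE, (hkid + 1) * GROUP_SIZE) while the
+        # K/V pointers, already offset by hkid above, are shared across the whole group.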
+ for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + # start can only be 0 at minimum + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N2) + + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, # + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, # + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2, BLOCK_N2, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + None, None, None, None, + MASK=False, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + IS_FP8=False, + FP8_MAX=None, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +def attention_prefill_backward_triton_split_oneKernel_impl( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + use_exp2: bool, +): + # debug + DEBUG_TRITON: bool = False + DEBUG_TRITON_DETAIL: bool = False + + # get strides and shape + batch, nheads_q, nheads_k, head_size, max_seqlen_q_final, max_seqlen_k_final = \ + get_shapes_from_layout( + q, k, layout, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k + ) + q_strides, k_strides, v_strides, o_strides = \ + get_strides_from_layout(q, k, v, o, layout) + stride_qb, stride_qh, stride_qm, stride_qk = q_strides + stride_kb, stride_kh, stride_kn, stride_kk = k_strides + stride_vb, stride_vh, stride_vn, stride_vk = v_strides + stride_ob, stride_oh, stride_om, stride_ok = o_strides + dq_strides, dk_strides, _, do_strides = \ + get_strides_from_layout(dq, dk, dv, do, layout) + stride_dqb, stride_dqh, stride_dqm, stride_dqk = dq_strides + stride_dkb, stride_dkh, stride_dkn, stride_dkk = dk_strides + stride_dob, stride_doh, stride_dom, stride_dok = do_strides + IS_VARLEN = layout == "thd" + use_dropout = (dropout_p > 0.0) + + # get closest 
power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) + HEAD_DIM = padded_d_model + ACTUAL_HEAD_DIM = head_size + + # init delta + delta = torch.empty_like(softmax_lse) + if IS_VARLEN: + stride_deltab = 0 + stride_deltam, stride_deltah = delta.stride() + else: + stride_deltab, stride_deltah, stride_deltam = delta.stride() + pre_grid = lambda META: (triton.cdiv(max_seqlen_q_final, META['PRE_BLOCK']), batch, nheads_q) + _bwd_preprocess[pre_grid]( + o, do, + delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + 0, + cu_seqlens_q, max_seqlen_q_final, + None, + HEAD_DIM=HEAD_DIM, + ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + IS_VARLEN=IS_VARLEN, + IS_FP8=False + ) + + # dropout mask tensor for debugging. We dump the dropout mask created in + # the kernel for testing + dropout_mask = None + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + (0, 0 , 0 , 0) + if use_dropout: + dropout_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), + device=q.device, + dtype=torch.float32 + ) + + if DROPOUT_USE_PYTORCH: + if not IS_VARLEN: + dropout_mask = create_dropout_mask( + dropout_p, + (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), + seed = philox_seed + ) + else: + dropout_mask = create_dropout_mask_varlen( + dropout_p, batch, nheads_q, + cu_seqlens_q, cu_seqlens_k, philox_seed + ) + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + dropout_mask.stride() + + seqlen = max(max_seqlen_q_final, max_seqlen_k_final) + grid = lambda META: ((seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch, nheads_k) + if causal: + if DEBUG_TRITON: print(f"bwd_kernel: grid = {grid}" ) # noqa: E701 + bwd_kernel_causal[grid]( + q, k, v, sm_scale, do, dq, dk, dv, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q_final, max_seqlen_k_final, + dropout_mask, dropout_p, philox_seed, philox_offset, + HEAD_DIM=HEAD_DIM, + ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + else: + bwd_kernel_noncausal[grid]( + q, k, v, sm_scale, do, dq, dk, dv, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q_final, max_seqlen_k_final, + dropout_mask, dropout_p, philox_seed, philox_offset, + HEAD_DIM=HEAD_DIM, + ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + return delta \ No newline at end of file diff --git 
a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py new file mode 100644 index 00000000000..5cc93edc5e4 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py @@ -0,0 +1,1354 @@ +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from typing import Literal, Optional +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, \ + get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_fp8 + +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) +tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) + +# This function computes delta given output Out and gradient DO +# Here is the I/O shape: +# Out: (batch, nhead_q, max_seqlens_q, headDim) +# DO: (batch, nhead_q, max_seqlens_q, headDim) +# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at +# fwd_prefill.py line 607 +@triton.jit +def _bwd_preprocess( + O, DO, # noqa: E741 + Delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + stride_descale_do_z, + cu_seqlens_q, max_seqlen_q, + Descale_do, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr +): + pid_m = tl.program_id(0) + bid = tl.program_id(1) + hid = tl.program_id(2) + # Handle varlen + q_start = 0 + seqlen_q = max_seqlen_q + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + seqlen_q = q_end - q_start + else: + q_start = 0 + seqlen_q = max_seqlen_q + + # Compute offsets + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, HEAD_DIM) + # Offset O/DO by batch, head and q_start + O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741 + DO += bid * stride_ob + hid * stride_oh + q_start * stride_om + # create masks + mask_m = offs_m < seqlen_q + mask_md = mask_m[:, None] + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_md &= offs_k[None, :] < ACTUAL_HEAD_DIM + # compute pointers + offs_do = offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok + out_ptrs = O + offs_do + do_ptrs = DO + offs_do + # load + o = tl.load(out_ptrs, mask=mask_md, other=0.0) + do = tl.load(do_ptrs, mask=mask_md, other=0.0) + # compute and write-back to delta + if IS_FP8: + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hid) + + # NOTE: do is in the fp8 range and o is not in fp8 + delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) + else: + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + delta_offset = Delta + bid * stride_deltab + hid * stride_deltah + q_start * stride_deltam + tl.store(delta_offset + offs_m * stride_deltam, delta, mask=mask_m) + + +# The main inner-loop logic for computing dK and dV. 
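+# Per tile, with P = softmax(QK^T * sm_scale) reconstructed from the saved row statistics M
+# (softmax_lse), the loop below accumulates:
+#   dV += P^T @ dO
+#   dP  = dO @ V^T,  dS = P * (dP - Delta),  where Delta_i = sum_d O[i, d] * dO[i, d]
+#   dK += dS^T @ Q   (sm_scale is not applied inside this helper)
+# The IS_FP8 branches additionally rescale every dot product by the per-tensor descale factors.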
+@triton.jit +def _bwd_dkdv_inner( + dk, dv, # output + Q, k, v, DO, M, D, sm_scale, # input tensor + stride_qm, stride_qk, + stride_dom, stride_dok, + stride_dropoutm, stride_dropoutn, + stride_deltam, + BLOCK_M: tl.constexpr, # 16 + BLOCK_N: tl.constexpr, # 128 + HEAD_DIM: tl.constexpr, # + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + seqlen_q, seqlen_k, # max sequence length for q and k + # Filled in by the wrapper. + start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal + ENABLE_DROPOUT: tl.constexpr, # activate dropout + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, # activate exp2 + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) + offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) + offs_k = tl.arange(0, HEAD_DIM) + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + # Q and DO are (seqlen_q, head_dim) + # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q + qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk + # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N % BLOCK_M == 0) + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + + for blk_idx in range(num_steps): + if DEBUG_TRITON: print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 + offs_m = curr_m + tl.arange(0, BLOCK_M) + # update the mask because offs_m advanced + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM + mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + # generate dropout mask + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = curr_philox_offset + \ + offs_m[None, :] * stride_dropoutm + \ + offs_n[:, None] * stride_dropoutn + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = offs_m[None, :] * stride_dropoutm + \ + offs_n[:, None] * stride_dropoutn + dropout_mask = tl.load( + curr_dropout_offset + dropout_offs, + mask=mask_nm + ) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + # Load m before computing qk to reduce pipeline stall. 
+ m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + if IS_FP8: + qkT = (tl.dot(k, qT) * descale_q * descale_k) + else: + qkT = tl.dot(k, qT) + qkT_scaled = qkT * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_n[:, None] + seqlen_q - seqlen_k - offs_m[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qkT_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"qT: {qT.shape}\n", qT) + print(f"k: {k.shape}\n", k) + print(f"qkT scaled: {qkT.shape}\n", qkT_scaled) + # TODO: remove the scaling of m later when we removed re-scaling in fwd + if USE_EXP2: + pT = tl.math.exp2(qkT_scaled * RCP_LN2 - m[None, :] * RCP_LN2) + else: + pT = tl.math.exp(qkT_scaled - m[None, :]) + + # Autoregressive masking. + if MASK: + # offset offs_m with delta_qk since the causal mask starts at + # bottom right of the (seqlen_q, seqlen_k) matrix + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"causal_mask: {causal_mask.shape}\n", causal_mask) + print(f"qkT after causal: {qkT.shape}\n", tl.where(causal_mask, qkT * sm_scale, 0.0)) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + # Compute dV. + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) + dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do)* descale_p_dropout * descale_do) + else: + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + if IS_FP8: + scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) + dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"pT: {pT.shape}\n", pT) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + # Compute dP and dS. + if IS_FP8: + dpT = (tl.dot(v, tl.trans(do)) * descale_v * descale_do) + else: + dpT = tl.dot(v, tl.trans(do)) + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + if IS_FP8: + scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + # Increment pointers. 
+ curr_m += step_m + qT_ptrs += step_m * stride_qm + do_ptrs += step_m * stride_dom + return dk, dv + + +# grid = (max_seqlen_k // BLOCK_N, batch, nheads_q) +@triton.jit +def _bwd_kernel_dkdv_causal( + Q, K, V, sm_scale, DO, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + Dropout_mask, dropout_p, philox_seed, philox_offset_base, + Alibi_slopes, + Descale_q, Descale_k, Descale_v, Descale_do, + BLOCK_M: tl.constexpr, # 32 + BLOCK_N: tl.constexpr, # 128 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. 
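+    # As a concrete illustration of the block-skipping logic below: with
+    # seqlen_q = 384, seqlen_k = 512, BLOCK_M = 32 and BLOCK_N = 128,
+    # delta_qk = -128 and num_blocks_skip = 1, so the K-tile at pid 0
+    # (keys 0..127) is visible to every query and skips the masked pass
+    # entirely, while the K-tile at pid 1 starts its masked pass at
+    # start_m = 0.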
+    delta_qk = seqlen_q - seqlen_k
+    if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}")
+    if DEBUG_TRITON: print(f"delta_qk = {delta_qk}")
+    # q > k: directly skip all the way to the start of the causal block
+    start_delta_q_gt_k = delta_qk
+    # q < k: some K-tiles have no masked block, the others need to recalculate
+    # their starting position
+    # delta_qk is negative so flip it; only multiples of BLOCK_N can skip the
+    # masked op
+    num_blocks_skip = -delta_qk // BLOCK_N
+    delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk
+    start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M
+    if delta_qk >= 0:
+        start_delta = delta_qk
+        if DEBUG_TRITON: print(f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}")
+    else:
+        start_delta = start_delta_q_lt_k
+        if DEBUG_TRITON: print(f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}")
+    # align the delta_qk
+    start_n = pid * BLOCK_N
+
+    offs_k = tl.arange(0, HEAD_DIM)
+    offs_n = start_n + tl.arange(0, BLOCK_N)
+    # Mask for loading K and V
+    mask_kv = offs_n[:, None] < seqlen_k
+    PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM)
+    if PADDED_HEAD:
+        mask_k = offs_k < ACTUAL_HEAD_DIM
+        mask_kv &= mask_k[None, :]
+
+    GROUP_SIZE = HQ // HK
+    # K/V tensors are not changed within the group
+    adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk
+    adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk
+    # load K and V: they stay in SRAM throughout the inner loop.
+    k = tl.load(K + adj_k, mask=mask_kv, other=0.0)
+    v = tl.load(V + adj_v, mask=mask_kv, other=0.0)
+    # If MQA / GQA, set the K and V head offsets appropriately.
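+    # e.g. with HQ = 32 query heads and HK = 8 KV heads, GROUP_SIZE = 4 and the
+    # loop below visits query heads 4*hkid .. 4*hkid + 3 for the current KV head.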
+ for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M * BLOCK_M + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N + residue_m + if DEBUG_TRITON: print(f"residue_m = {residue_m}") + + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = Dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M) + # when q < k, we may skip the initial masked op + if pid < num_blocks_skip: + num_steps = 0 + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}") + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + MASK_BLOCK_M, BLOCK_N, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + alibi_slope, + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + start_m += num_steps * MASK_BLOCK_M + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) + end_m = start_m + num_steps * BLOCK_M + + if DEBUG_TRITON: print(f"start_m after Masked step: {start_m}; num_steps: {num_steps}") # noqa: E701 + if DEBUG_TRITON: print(f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: 
{num_steps}") # noqa: E701 + if DEBUG_TRITON: print("unMasked") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M, BLOCK_N, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + alibi_slope, + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV and dK. + adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn + offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk + tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + + +# the main inner-loop logic for computing dQ +@triton.jit +def _bwd_dq_inner( + dq, # output + q, K, V, do, m, Delta, sm_scale, # input + # shared by Q/K/V. + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, # stride for dropout + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + # Filled in by the wrapper. + start_m, start_n, end_n, num_steps, # + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + for blk_idx in range(num_steps): + if DEBUG_TRITON: print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 + offs_n = curr_n + tl.arange(0, BLOCK_N2) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 + if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 + mask_kT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD: + mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = curr_philox_offset + \ + offs_m[:, None] * stride_dropoutm + \ + offs_n[None, :] * stride_dropoutn + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = offs_m[:, None] * stride_dropoutm + \ + offs_n[None, :] * stride_dropoutn + dropout_mask = tl.load( + curr_dropout_offset + dropout_offs, + mask=mask_mn) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + if IS_FP8: + qk = (tl.dot(q, kT) * descale_q * descale_k) + else: + qk = tl.dot(q, kT) + qk_scaled = qk * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qk_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk_scaled) # noqa: E701 + if USE_EXP2: + p = tl.math.exp2(qk_scaled * RCP_LN2 - m * RCP_LN2) + else: + p = tl.math.exp(qk_scaled - m) + + # Autoregressive masking. + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask & mask_mn + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + if IS_FP8: + dp = (tl.dot(do, vT) * descale_do * descale_v) + else: + dp = tl.dot(do, vT) + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + delta_i = Di[:, None] + ds = p * (dp -delta_i) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + if IS_FP8: + scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) + dq += (tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) * descale_ds * descale_k) + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + # Increment pointers. 
+ curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + + +# grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) +@triton.jit +def _bwd_kernel_dq_causal( + Q, K, V, sm_scale, DO, DQ, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + Dropout_mask, dropout_p, philox_seed, philox_offset_base, + Alibi_slopes, + Descale_q, Descale_k, Descale_v, Descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + # Figure out causal starting block since we have seqlen_q <=> seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. + # DQ tiles on M dim and iterate on N dim, so we there could be some tiles we + # can simply skip and we need to adjust starting position. + start_m = pid * BLOCK_M + # seqlen_q > seqlen_k, no need to process these tile for dq + delta_qk = seqlen_q - seqlen_k + if DEBUG_TRITON: print(f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M} = {start_m + BLOCK_M}") # noqa: E701 + if start_m + BLOCK_M < delta_qk: + if DEBUG_TRITON: print(f"start_m + BLOCK_M = {start_m} + {BLOCK_M} = {start_m + BLOCK_M} < delta_qk of {delta_qk}") # noqa: E701 + return + + offs_k = tl.arange(0, HEAD_DIM) + offs_m = start_m + tl.arange(0, BLOCK_M) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + K += adj_k + V += adj_v + # If MQA / GQA, set the K and V head offsets appropriately. 
+ GROUP_SIZE = HQ // HK + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + if DEBUG_TRITON: print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N) + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + if DEBUG_TRITON: print(f"pid: {pid}; end_n: {end_n}, start_m: {start_m}") # noqa: E701 + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _bwd_dq_inner, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. 
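+        # Illustration: with seqlen_q = seqlen_k = 512 (delta_qk = 0),
+        # BLOCK_M = 128, BLOCK_N = 32 and BLK_SLICE_FACTOR = 2, the M-tile at
+        # pid 1 (start_m = 128) gets end_n = 256: the masked pass walks keys
+        # 128..255 in 8 steps of MASK_BLOCK_N = 16, and the unmasked pass
+        # afterwards walks keys 0..127 in 4 steps of BLOCK_N = 32.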
+ if DEBUG_TRITON: print(f"Masked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, Delta_ptr, sm_scale, + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, + stride_deltam, + seqlen_q, seqlen_k, + BLOCK_M, MASK_BLOCK_N, + HEAD_DIM, ACTUAL_HEAD_DIM, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + MASK=True, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + end_n -= num_steps * MASK_BLOCK_N + num_steps = tl.cdiv(end_n, BLOCK_N) + start_n = max(end_n - num_steps * BLOCK_N, 0) + if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, Delta_ptr, sm_scale, + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, + stride_deltam, + seqlen_q, seqlen_k, + BLOCK_M, BLOCK_N, + HEAD_DIM, ACTUAL_HEAD_DIM, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +@triton.jit +def _bwd_kernel_dkdv_noncausal( + Q, K, V, sm_scale, DO, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + Dropout_mask, dropout_p, philox_seed, philox_offset_base, + Alibi_slopes, + Descale_q, Descale_k, Descale_v, Descale_do, + BLOCK_M: tl.constexpr, # 32 + BLOCK_N: tl.constexpr, # 128 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, HEAD_DIM], 
dtype=tl.float32) + dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, HEAD_DIM) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_kv &= mask_k[None, :] + + GROUP_SIZE = HQ // HK + # K/V tensors not changed for the group + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = Dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # because there is no causal, we always start from the beginning + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M) + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M, BLOCK_N, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + alibi_slope, + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV and dK. 
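+    # dk accumulated dS^T @ Q without the softmax scale, so sm_scale is applied
+    # once here at write-back; dv needs no extra scaling.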
+ adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn + offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk + tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_dq_noncausal( + Q, K, V, sm_scale, DO, DQ, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + Dropout_mask, dropout_p, philox_seed, philox_offset_base, + Alibi_slopes, + Descale_q, Descale_k, Descale_v, Descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + start_m = pid * BLOCK_M + + offs_k = tl.arange(0, HEAD_DIM) + offs_m = start_m + tl.arange(0, BLOCK_M) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + K += adj_k + V += adj_v + # If MQA / GQA, set the K and V head offsets appropriately. 
+ GROUP_SIZE = HQ // HK + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # start can only be 0 at minimum + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N) + dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, Delta_ptr, sm_scale, + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, + stride_deltam, + seqlen_q, seqlen_k, + BLOCK_M, BLOCK_N, + HEAD_DIM, ACTUAL_HEAD_DIM, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. 
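+        # like dk, dq is accumulated without the softmax scale and sm_scale is
+        # applied once at write-back.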
+ adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +def attention_prefill_backward_triton_split_impl( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + use_exp2: bool, + # fp8 + descale_q: Optional[torch.Tensor], + descale_k: Optional[torch.Tensor], + descale_v: Optional[torch.Tensor], + descale_o: Optional[torch.Tensor], + descale_do: Optional[torch.Tensor], + descale_dq: Optional[torch.Tensor], + descale_dk: Optional[torch.Tensor], + descale_dv: Optional[torch.Tensor], +): + # debug + DEBUG_TRITON: bool = False + DEBUG_TRITON_DETAIL: bool = False + + # fp8 + IS_FP8 = is_fp8(q) + if IS_FP8: + FP8_MAX = torch.finfo(q.dtype).max + # assert that the main inputs are fp8 + assert is_fp8(do) and is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: do.dtype={do.dtype}, q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8." + if is_fp8(o): + FP8_OUTPUT = True + assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." + assert descale_dq is not None, f"descale_dq is None. In fp8, you need to pass a tensor for descale_dq along with a tensor dq." + assert descale_dk is not None, f"descale_dk is None. In fp8, you need to pass a tensor for descale_dk along with a tensor dk." + assert descale_dv is not None, f"descale_dv is None. In fp8, you need to pass a tensor for descale_dv along with a tensor dv." 
+ else: + FP8_OUTPUT = False + + stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None + stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None + stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None + stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None + stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None + else: + FP8_MAX = None + FP8_OUTPUT = False + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = stride_descale_do_z = None + + + # get strides and shape + batch, nheads_q, nheads_k, head_size, max_seqlen_q_final, max_seqlen_k_final = \ + get_shapes_from_layout( + q, k, layout, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k + ) + q_strides, k_strides, v_strides, o_strides = \ + get_strides_from_layout(q, k, v, o, layout) + stride_qb, stride_qh, stride_qm, stride_qk = q_strides + stride_kb, stride_kh, stride_kn, stride_kk = k_strides + stride_vb, stride_vh, stride_vn, stride_vk = v_strides + stride_ob, stride_oh, stride_om, stride_ok = o_strides + dq_strides, dk_strides, dv_strides, do_strides = \ + get_strides_from_layout(dq, dk, dv, do, layout) + stride_dqb, stride_dqh, stride_dqm, stride_dqk = dq_strides + stride_dkb, stride_dkh, stride_dkn, stride_dkk = dk_strides + stride_dvb, stride_dvh, stride_dvn, stride_dvk = dv_strides + stride_dob, stride_doh, stride_dom, stride_dok = do_strides + IS_VARLEN = layout == "thd" + use_dropout = (dropout_p > 0.0) + use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) + + # get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 32) # NOTE: the causal path expects a min of 32. It will cause a compiler assert. + HEAD_DIM = padded_d_model + ACTUAL_HEAD_DIM = head_size + # meta-parameters + # TODO: fix num_stages later + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + PRE_BLOCK = 128 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + + # init delta + delta = torch.zeros_like(softmax_lse) + if IS_VARLEN: + stride_deltab = 0 + stride_deltah, stride_deltam = delta.stride() + else: + stride_deltab, stride_deltah, stride_deltam = delta.stride() + pre_grid = (triton.cdiv(max_seqlen_q_final, PRE_BLOCK), batch, nheads_q) + _bwd_preprocess[pre_grid]( + o, do, + delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + stride_descale_do_z, + cu_seqlens_q, max_seqlen_q_final, + descale_do, + BLOCK_M=PRE_BLOCK, + HEAD_DIM=HEAD_DIM, + ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8 + ) + + if DEBUG: + print("delta:", delta, delta.shape) + + # dropout mask tensor for debugging. 
We dump the dropout mask created in + # the kernel for testing + dropout_mask = None + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + (0, 0 , 0 , 0) + if use_dropout: + dropout_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), + device=q.device, + dtype=torch.float32 + ) + + if DROPOUT_USE_PYTORCH: + if not IS_VARLEN: + dropout_mask = create_dropout_mask( + dropout_p, + (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), + seed = philox_seed + ) + else: + dropout_mask = create_dropout_mask_varlen( + dropout_p, batch, nheads_q, + cu_seqlens_q, cu_seqlens_k, philox_seed + ) + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + dropout_mask.stride() + + grid_dkdv = ((max_seqlen_k_final + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) + grid_dq = ((max_seqlen_q_final + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) + if causal: + if DEBUG_TRITON: print(f"_bwd_kernel_dkdv: grid = {grid_dkdv}, block_size = ({BLOCK_M1, BLOCK_N1})", ) # noqa: E701 + _bwd_kernel_dkdv_causal[grid_dkdv]( + q, k, v, sm_scale, do, dk, dv, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q_final, max_seqlen_k_final, + dropout_mask, dropout_p, philox_seed, philox_offset, + alibi_slopes, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + if DEBUG_TRITON: print(f"\n_bwd_kernel_dq: grid = {grid_dq}, block_size = ({BLOCK_M2, BLOCK_N2})", ) # noqa: E701 + _bwd_kernel_dq_causal[grid_dq]( + q, k, v, sm_scale, do, dq, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q_final, max_seqlen_k_final, + dropout_mask, dropout_p, philox_seed, philox_offset, + alibi_slopes, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + else: + _bwd_kernel_dkdv_noncausal[grid_dkdv]( + q, k, v, sm_scale, do, dk, dv, + softmax_lse, delta, + 
stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q_final, max_seqlen_k_final, + dropout_mask, dropout_p, philox_seed, philox_offset, + alibi_slopes, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + _bwd_kernel_dq_noncausal[grid_dq]( + q, k, v, sm_scale, do, dq, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q_final, max_seqlen_k_final, + dropout_mask, dropout_p, philox_seed, philox_offset, + alibi_slopes, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + return delta diff --git a/flash_attn/flash_attn_triton_amd/bwd_ref.py b/flash_attn/flash_attn_triton_amd/bwd_ref.py index 7ea7c32bf7f..90a98ce4fcc 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/bwd_ref.py @@ -1,11 +1,14 @@ import torch import math -from .utils import DEBUG +from typing import Literal, Optional +from .utils import DEBUG, compute_alibi_tensor_ref + +DEBUG_CORE = False def attention_backward_core_ref_impl( - do, q, k, v, o, softmax_lse, sm_scale, causal, use_exp2 + do, q, k, v, o, softmax_lse, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2 ): - if DEBUG: + if DEBUG_CORE: print() print("attention_backward_core_ref_impl") print("do:", do, do.shape) @@ -16,6 +19,9 @@ def attention_backward_core_ref_impl( print("softmax_lse:", softmax_lse, softmax_lse.shape) print("sm_scale:", sm_scale) print("causal:", causal) + print("dropout_p:", dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:", philox_offset) print("use_exp2:", use_exp2) # cast to float32 @@ -28,15 +34,27 @@ def attention_backward_core_ref_impl( # recompute attention_scores. Make sure it matches the forward impl. i.e. 
It use float32 - attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32)) - if DEBUG: + attention_scores = torch.matmul(q, k.transpose(-2, -1)) + if DEBUG_CORE: print("attention_scores:", attention_scores, attention_scores.shape) # scale scores attention_scaled_scores = sm_scale * attention_scores - if DEBUG: + if DEBUG_CORE: print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape) + if alibi_slopes is not None: + L_q, L_k = q.shape[1], k.shape[1] + if DEBUG_CORE: + print("alibi_slopes:", alibi_slopes, alibi_slopes.shape) + alibi_bias = compute_alibi_tensor_ref(alibi_slopes, L_q, L_k) + alibi_bias = alibi_bias.reshape(-1, L_q, L_k) + if True: + print("alibi_bias:", alibi_bias, alibi_bias.shape) + attention_scaled_scores = attention_scaled_scores + alibi_bias + if DEBUG_CORE: + print("attention_scaled_scores after alibi:", attention_scaled_scores, attention_scaled_scores.shape) + # Apply causal mask if necessary if causal: L_q, L_k = q.shape[1], k.shape[1] @@ -44,13 +62,13 @@ def attention_backward_core_ref_impl( col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) col_offset = L_q-L_k causal_mask = row_idx >= (col_offset + col_idx) - if DEBUG: + if DEBUG_CORE: print("causal_mask:", causal_mask) # set -inf to places the causal mask is false attention_scaled_scores = attention_scaled_scores.masked_fill( torch.logical_not(causal_mask.unsqueeze(0)), float('-inf') ) - if DEBUG: + if DEBUG_CORE: print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape) # compute probabilities using softmax_lse @@ -63,58 +81,79 @@ def attention_backward_core_ref_impl( else: softmax_lse_3d = softmax_lse.unsqueeze(-1) p = torch.exp(attention_scaled_scores - softmax_lse_3d) - - if DEBUG: + if DEBUG_CORE: print("softmax_lse_3d:", softmax_lse_3d, softmax_lse_3d.shape) print("p:", p, p.shape) - # compute gradient wrt v - dv = torch.matmul(p.transpose(-2, -1), do.to(torch.float32)) - if DEBUG: - print("dv:", dv, dv.shape) - # compute dp - dp = torch.matmul(do, v.transpose(-2, -1)) - if DEBUG: - print("dp:", dp, dp.shape) - - # calculate ds using dp - if True: - delta = torch.sum(o * do, axis=-1).to(torch.float32) # what OAI kernel uses - delta_3d = delta.unsqueeze(-1) + if dropout_p > 0.0: + rand_vals = torch.rand(p.shape, generator=torch.Generator(device=p.device).manual_seed(philox_seed), device=p.device, dtype=p.dtype) + dropout_mask, dropout_scale = rand_vals > dropout_p, (1.0 / (1 - dropout_p)) + if DEBUG: + print("dropout_scale:", dropout_scale) + print("dropout_mask:", dropout_mask) + + p_drop = torch.where(dropout_mask, p, torch.zeros_like(p)) + p_drop_scaled = p_drop * dropout_scale + if DEBUG_CORE: + print("dropout_scale:", dropout_scale) + print("p_drop:", p_drop, p_drop.shape) + print("p_drop_scaled:", p_drop_scaled, p_drop_scaled.shape) + + # compute dv + dv = torch.matmul(p_drop_scaled.transpose(-2, -1), do) + if DEBUG_CORE: + print("dv:", dv, dv.shape) + + # compute dp + dp_dropout = torch.matmul(do, v.transpose(-2, -1)) + dp = torch.where(dropout_mask, dp_dropout , torch.zeros_like(dp_dropout)) * dropout_scale + if DEBUG_CORE: + print("dp_dropout:", dp_dropout, dp_dropout.shape) + print("dp:", dp, dp.shape) else: - delta = torch.sum(p * dp, axis=-1) # what the math says you should use - delta_3d = delta.unsqueeze(-1) - if DEBUG: - print("delta_3d:", delta_3d, delta_3d.shape) - ds = (p * (dp - delta_3d)) * sm_scale + # compute dv + dv = torch.matmul(p.transpose(-2, -1), do) + if 
DEBUG_CORE: + print("dv:", dv, dv.shape) + + # compute dp + dp = torch.matmul(do, v.transpose(-2, -1)) + if DEBUG_CORE: + print("dp:", dp, dp.shape) + + # calculate ds + if False: + delta = torch.sum(o * do, axis=-1).unsqueeze(-1) + else: + delta = torch.sum(p * dp, axis=-1).unsqueeze(-1) if DEBUG: + print("delta:", delta, delta.shape) + dscores_scaled = p * (dp - delta) + ds = dscores_scaled * sm_scale + if DEBUG_CORE: + print("dscores_scaled:", dscores_scaled, dscores_scaled.shape) print("ds:", ds, ds.shape) - - # compute gradient wrt k - dk = torch.matmul(ds.transpose(-2, -1), q.to(torch.float32)) - if DEBUG: + # compute gradient wrt k & q + dk = torch.matmul(ds.transpose(-2, -1), q) + dq = torch.matmul(ds, k) + if DEBUG_CORE: print("dk:", dk, dk.shape) - - # compute gradient wrt q - dq = torch.matmul(ds, k.to(torch.float32)) - if DEBUG: print("dq:", dq, dq.shape) # cast back to original dtype dq = dq.to(torch.float16) dk = dk.to(torch.float16) dv = dv.to(torch.float16) - # remove d dim with size 1 - delta = delta_3d.squeeze(-1) + delta = delta.squeeze(-1) - if DEBUG: + if DEBUG_CORE: print("attention_backward_core_ref_impl output") - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) print("delta:", delta, delta.shape) + print("dv:", dv, dv.shape) + print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) return dq, dk, dv, delta @@ -132,6 +171,10 @@ def attention_varlen_backward_pytorch_ref_impl( cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2, ): # Ensure the layout is 'thd' @@ -139,8 +182,12 @@ def attention_varlen_backward_pytorch_ref_impl( raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.") batch_size = cu_seqlens_q.shape[0] - 1 - num_heads = q.shape[1] - head_dim = q.shape[2] + nheads_q, head_dim = q.shape[1], q.shape[2] + nheads_k = k.shape[1] + + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") # Pre-allocate outputs total_L_q = q.shape[0] @@ -149,8 +196,8 @@ def attention_varlen_backward_pytorch_ref_impl( dq = torch.zeros_like(q) dk = torch.zeros_like(k) dv = torch.zeros_like(v) - # delta has the same shape as softmax_lse: [total_L_q, num_heads] - delta = torch.zeros((total_L_q, num_heads), dtype=torch.float32, device=o.device) + # delta has the same shape as softmax_lse: [total_L_q, nheads_q] + delta = torch.zeros((total_L_q, nheads_q), dtype=torch.float32, device=o.device) for i in range(batch_size): # Get the start and end indices for the current sequence @@ -160,22 +207,41 @@ def attention_varlen_backward_pytorch_ref_impl( end_k = cu_seqlens_k[i + 1].item() # Extract q_i, k_i, v_i, do_i, o_i, softmax_lse_i - q_i = q[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim] - k_i = k[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim] - v_i = v[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim] - do_i = do[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim] - o_i = o[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim] - # softmax_lse has shape [total_L_q, num_heads] - softmax_lse_i = softmax_lse[start_q:end_q, :] # [L_q_i, num_heads] - softmax_lse_i = softmax_lse_i.transpose(0, 1) # [num_heads, L_q_i] - - # Permute to [num_heads, L_q_i, head_dim] + q_i = q[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] + k_i = k[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] + v_i = v[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] + do_i = do[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] + o_i 
= o[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] + softmax_lse_i = softmax_lse[start_q:end_q, :] # [L_q_i, nheads_q] + + if group_size != 1: + # MQA or GQA case + # Reshape tensors to include group dimension + q_i = q_i.view(q_i.shape[0], nheads_k, group_size, head_dim) + do_i = do_i.view(do_i.shape[0], nheads_k, group_size, head_dim) + o_i = o_i.view(o_i.shape[0], nheads_k, group_size, head_dim) + softmax_lse_i = softmax_lse_i.view(softmax_lse_i.shape[0], nheads_k, group_size) + # Expand k_i and v_i to match group_size + k_i = k_i.unsqueeze(2).expand(-1, -1, group_size, -1) + v_i = v_i.unsqueeze(2).expand(-1, -1, group_size, -1) + # Flatten the nheads_k and group_size dimensions + q_i = q_i.reshape(q_i.shape[0], nheads_k * group_size, head_dim) + do_i = do_i.reshape(do_i.shape[0], nheads_k * group_size, head_dim) + o_i = o_i.reshape(o_i.shape[0], nheads_k * group_size, head_dim) + softmax_lse_i = softmax_lse_i.reshape(softmax_lse_i.shape[0], nheads_k * group_size) + k_i = k_i.reshape(k_i.shape[0], nheads_k * group_size, head_dim) + v_i = v_i.reshape(v_i.shape[0], nheads_k * group_size, head_dim) + # Permute to [nheads_total, L, head_dim] q_i = q_i.permute(1, 0, 2) k_i = k_i.permute(1, 0, 2) v_i = v_i.permute(1, 0, 2) do_i = do_i.permute(1, 0, 2) o_i = o_i.permute(1, 0, 2) - # softmax_lse_i is already in [num_heads, L_q_i] + softmax_lse_i = softmax_lse_i.transpose(0, 1) + if alibi_slopes is not None: + alibi_slopes_i = alibi_slopes[i] + else: + alibi_slopes_i = None # Call the core backward function for this sequence dq_i, dk_i, dv_i, delta_i = attention_backward_core_ref_impl( @@ -187,20 +253,39 @@ def attention_varlen_backward_pytorch_ref_impl( softmax_lse_i, sm_scale, causal, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes_i, use_exp2 ) # Convert back to 'thd' layout - dq_i = dq_i.permute(1, 0, 2) # [L_q_i, num_heads, head_dim] - dk_i = dk_i.permute(1, 0, 2) # [L_k_i, num_heads, head_dim] - dv_i = dv_i.permute(1, 0, 2) # [L_k_i, num_heads, head_dim] + dq_i = dq_i.permute(1, 0, 2) # [L_q_i, nheads_total, head_dim] + dk_i = dk_i.permute(1, 0, 2) # [L_k_i, nheads_total, head_dim] + dv_i = dv_i.permute(1, 0, 2) # [L_k_i, nheads_total, head_dim] + delta_i = delta_i.transpose(1, 0) # [L_q_i, nheads_total] + + if group_size != 1: + # Reshape dq_i and delta_i back to original shape + dq_i = dq_i.view(dq_i.shape[0], nheads_k, group_size, head_dim) + delta_i = delta_i.view(delta_i.shape[0], nheads_k, group_size) + # Sum dk_i and dv_i over group dimension + dk_i = dk_i.view(dk_i.shape[0], nheads_k, group_size, head_dim) + dv_i = dv_i.view(dv_i.shape[0], nheads_k, group_size, head_dim) + dk_i = dk_i.sum(dim=2) + dv_i = dv_i.sum(dim=2) + # Reshape dq_i back to [L_q_i, nheads_q, head_dim] + dq_i = dq_i.reshape(dq_i.shape[0], nheads_q, head_dim) + delta_i = delta_i.reshape(delta_i.shape[0], nheads_q) + else: + # No need to reshape + pass # Place outputs in pre-allocated tensors dq[start_q:end_q, :, :] = dq_i dk[start_k:end_k, :, :] += dk_i # Accumulate gradients for shared keys dv[start_k:end_k, :, :] += dv_i # Accumulate gradients for shared values - # delta_i has shape [num_heads, L_q_i] - delta_i = delta_i.transpose(1, 0) # [L_q_i, num_heads] delta[start_q:end_q, :] = delta_i return dq, dk, dv, delta @@ -215,6 +300,10 @@ def attention_vanilla_backward_pytorch_ref_impl( sm_scale, causal, layout, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2, ): if layout == "bshd": @@ -231,18 +320,42 @@ def attention_vanilla_backward_pytorch_ref_impl( else: raise 
ValueError(f"Unknown layout {layout}") - # Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format - batch_size, num_heads, seq_len_q, head_dim = q.shape - seq_len_k = k.shape[2] - - # Merge batch and heads dimensions - do = do.reshape(batch_size * num_heads, seq_len_q, head_dim) - q = q.reshape(batch_size * num_heads, seq_len_q, head_dim) - k = k.reshape(batch_size * num_heads, seq_len_k, head_dim) - v = v.reshape(batch_size * num_heads, seq_len_k, head_dim) - softmax_lse = softmax_lse.reshape(batch_size * num_heads, seq_len_q) - o = o.reshape(batch_size * num_heads, seq_len_q, head_dim) - + # Prepare tensors + batch_size, nheads_q, seq_len_q, head_dim = q.shape + batch_size, nheads_k, seq_len_k, head_dim = k.shape + + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") + + if group_size != 1: + # MQA or GQA case + # Reshape do, q, o to [batch_size, nheads_k, group_size, seq_len_q, head_dim] + do = do.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + q = q.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + o = o.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + # Reshape softmax_lse to [batch_size, nheads_k, group_size, seq_len_q] + softmax_lse = softmax_lse.reshape(batch_size, nheads_k, group_size, seq_len_q) + # Expand k and v to match group_size + k = k.unsqueeze(2).expand(-1, -1, group_size, -1, -1) # [batch_size, nheads_k, group_size, seq_len_k, head_dim] + v = v.unsqueeze(2).expand(-1, -1, group_size, -1, -1) + # Flatten the first three dimensions for computation + do = do.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) + q = q.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) + k = k.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) + v = v.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) + o = o.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size * nheads_k * group_size, seq_len_q) + else: + # Standard case + do = do.reshape(batch_size * nheads_q, seq_len_q, head_dim) + q = q.reshape(batch_size * nheads_q, seq_len_q, head_dim) + k = k.reshape(batch_size * nheads_k, seq_len_k, head_dim) + v = v.reshape(batch_size * nheads_k, seq_len_k, head_dim) + o = o.reshape(batch_size * nheads_q, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size * nheads_q, seq_len_q) + + # Call the core backward function dq, dk, dv, delta = attention_backward_core_ref_impl( do, q, @@ -252,14 +365,32 @@ def attention_vanilla_backward_pytorch_ref_impl( softmax_lse, sm_scale, causal, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2 ) - # Reshape outputs back to [batch_size, num_heads, seq_len, head_dim] - dq = dq.reshape(batch_size, num_heads, seq_len_q, head_dim) - dk = dk.reshape(batch_size, num_heads, seq_len_k, head_dim) - dv = dv.reshape(batch_size, num_heads, seq_len_k, head_dim) - delta = delta.reshape(batch_size, num_heads, seq_len_q) + if group_size != 1: + # Reshape dq back to [batch_size, nheads_k, group_size, seq_len_q, head_dim] + dq = dq.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + # Reshape delta back to [batch_size, nheads_k, group_size, seq_len_q] + delta = delta.reshape(batch_size, nheads_k, group_size, seq_len_q) + # Sum dk and dv over group_size dimension, since k and v are shared across groups + dk = dk.reshape(batch_size, nheads_k, group_size, seq_len_k, head_dim) + dk = 
dk.sum(dim=2) # Sum over group_size dimension + dv = dv.reshape(batch_size, nheads_k, group_size, seq_len_k, head_dim) + dv = dv.sum(dim=2) + # Reshape dq to [batch_size, nheads_q, seq_len_q, head_dim] + dq = dq.reshape(batch_size, nheads_k * group_size, seq_len_q, head_dim) + delta = delta.reshape(batch_size, nheads_k * group_size, seq_len_q) + else: + # Standard case + dq = dq.reshape(batch_size, nheads_q, seq_len_q, head_dim) + dk = dk.reshape(batch_size, nheads_k, seq_len_k, head_dim) + dv = dv.reshape(batch_size, nheads_k, seq_len_k, head_dim) + delta = delta.reshape(batch_size, nheads_q, seq_len_q) # Go back to original layout if layout == "bshd": @@ -276,25 +407,31 @@ def attention_vanilla_backward_pytorch_ref_impl( return dq, dk, dv, delta - def attention_backward_pytorch_ref_impl( - do, - q, - k, - v, - o, - softmax_lse, - sm_scale, - causal, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - use_exp2 + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + use_exp2: bool ): if layout == "thd": - dq, dk, dv, delta = attention_varlen_backward_pytorch_ref_impl( + dq_ref, dk_ref, dv_ref, delta = attention_varlen_backward_pytorch_ref_impl( do, q, k, @@ -308,10 +445,14 @@ def attention_backward_pytorch_ref_impl( cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2, ) else: - dq, dk, dv, delta = attention_vanilla_backward_pytorch_ref_impl( + dq_ref, dk_ref, dv_ref, delta = attention_vanilla_backward_pytorch_ref_impl( do, q, k, @@ -321,8 +462,17 @@ def attention_backward_pytorch_ref_impl( sm_scale, causal, layout, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2, ) - return dq, dk, dv, delta + # copy into output tensor + dv.copy_(dv_ref.to(dv.dtype)) + dk.copy_(dk_ref.to(dk.dtype)) + dq.copy_(dq_ref.to(dq.dtype)) + + return delta diff --git a/flash_attn/flash_attn_triton_amd/fp8.py b/flash_attn/flash_attn_triton_amd/fp8.py new file mode 100644 index 00000000000..df79c7926b2 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/fp8.py @@ -0,0 +1,716 @@ +from typing import Optional, Sequence, Tuple, Union +import torch +import torch.nn as nn +from .utils import cast_to_fp8, is_fp8 +from . 
import interface_fa as flash_attn_gpu + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + +class FlashAttnFP8Func(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + is_grad_enabled, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(3) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + # figure out fwd parameters + if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" + q_fp8 = q + k_fp8 = k + v_fp8 = v + out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." + q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd") + k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd") + v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd") + out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None + + q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] + _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( + q_fp8, + k_fp8, + v_fp8, + out_fp8, + alibi_slopes, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + softcap=softcap, + return_softmax=return_softmax and dropout_p > 0, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_o=descale_o + ) + if is_grad: + ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + out = out_fp8[..., :head_size_og] # NOTE: this used to be out_padded. 
It might cause issue doing an empty + + # check output type + assert out.dtype == q.dtype, "Input and output type must match otherwise there will be implicit casting by autograd" + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors + head_size_og = dout.size(3) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + + # figure out bwd parameters + if is_fp8(dout): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_do is not None), f"You need to pass descale factors for do" + dout_padded_fp8 = dout_padded + dq, descale_dq = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) + dk, descale_dk = torch.zeros_like(k_fp8), torch.zeros_like(descale_k) + dv, descale_dv = torch.zeros_like(v_fp8), torch.zeros_like(descale_v) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." + dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "bshd") + dq, descale_dq = torch.zeros_like(q_fp8, dtype=torch.float32), None + dk, descale_dk = torch.zeros_like(k_fp8, dtype=torch.float32), None + dv, descale_dv = torch.zeros_like(v_fp8, dtype=torch.float32), None + + # dq, dk, dv are allocated by us so they should already be contiguous + dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] + flash_attn_gpu.bwd( + dout_padded_fp8, + q_fp8, + k_fp8, + v_fp8, + out_fp8, + softmax_lse, + dq, + dk, + dv, + ctx.alibi_slopes, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.deterministic, + None, # gen_ + rng_state, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + +def flash_attn_fp8_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None +): + return FlashAttnFP8Func.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + torch.is_grad_enabled(), + descale_q, + descale_k, + descale_v, + descale_do + ) + +class FlashAttnVarlenFP8Func(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + block_table, + is_grad_enabled, + descale_q: 
Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + # figure out fwd parameters + if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" + q_fp8 = q + k_fp8 = k + v_fp8 = v + out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." + q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_q, max_seqlen=max_seqlen_q) + k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_k, max_seqlen=max_seqlen_k) + v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_k, max_seqlen=max_seqlen_k) + out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None + + q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] + _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( + q_fp8, + k_fp8, + v_fp8, + out_fp8, + cu_seqlens_q, + cu_seqlens_k, + None, + None, + block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + window_size[0], + window_size[1], + softcap, + return_softmax, + None, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_o=descale_o + ) + if is_grad: + ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + out = out_fp8[..., :head_size_og] # NOTE: this used to be out_padded. 
It might cause issue doing an empty + + # check output type + assert out.dtype == q.dtype, "Input and output type must match otherwise there will be implicit casting by autograd" + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors + head_size_og = dout.size(2) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + + # figure out bwd parameters + if is_fp8(dout_padded): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_do is not None), f"You need to pass descale factors for do" + dout_padded_fp8 = dout_padded + dq, descale_dq = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) + dk, descale_dk = torch.zeros_like(k_fp8), torch.zeros_like(descale_k) + dv, descale_dv = torch.zeros_like(v_fp8), torch.zeros_like(descale_v) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." + dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_q, max_seqlen=ctx.max_seqlen_q) + dq, descale_dq = torch.zeros_like(q_fp8, dtype=torch.float32), None + dk, descale_dk = torch.zeros_like(k_fp8, dtype=torch.float32), None + dv, descale_dv = torch.zeros_like(v_fp8, dtype=torch.float32), None + + # dq, dk, dv are allocated by us so they should already be contiguous + dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] + flash_attn_gpu.varlen_bwd( + dout_padded_fp8, + q_fp8, + k_fp8, + v_fp8, + out_fp8, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + ctx.alibi_slopes, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + False, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.deterministic, + None, + rng_state, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + + +def flash_attn_varlen_fp8_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + block_table=None +): + return FlashAttnVarlenFP8Func.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + block_table, + torch.is_grad_enabled() + ) + +class FlashAttnQKVPackedFP8Func(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + 
is_grad_enabled, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None + ): + is_grad = is_grad_enabled and qkv.requires_grad + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach() + head_size_og = q.size(3) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + # figure out fwd parameters + if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" + q_fp8 = q + k_fp8 = k + v_fp8 = v + out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." + q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd") + k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd") + v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd") + out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None + + q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] + _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( + q_fp8, + k_fp8, + v_fp8, + out_fp8, + alibi_slopes, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + softcap=softcap, + return_softmax=return_softmax and dropout_p > 0, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_o=descale_o, + ) + if is_grad: + ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + out = out_fp8[..., :head_size_og] + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors + qkv_shape = q_fp8.shape[:-2] + (3, *q_fp8.shape[-2:]) + head_size_og = dout.size(3) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + + # figure out bwd parameters + if is_fp8(dout): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_do is not None), f"You need to pass descale factors for do" + dout_padded_fp8 = dout_padded + dqkv, descale_dqkv = torch.zeros(qkv_shape, device=q_fp8.device), torch.zeros_like(descale_q) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. 
In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." + dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "bshd") + dqkv, descale_dqkv = torch.zeros(qkv_shape, dtype=torch.float32, device=q_fp8.device), None + + + # dq, dk, dv are allocated by us so they should already be contiguous + dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] + flash_attn_gpu.bwd( + dout_padded_fp8, + q_fp8, + k_fp8, + v_fp8, + out_fp8, + softmax_lse, + dqkv[:, :, 0], + dqkv[:, :, 1], + dqkv[:, :, 2], + ctx.alibi_slopes, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.deterministic, + None, # gen_ + rng_state, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + None, + None, + None, + ) + dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension + return dqkv, None, None, None, None, None, None, None, None, None + + +def flash_attn_qkvpacked_fp8_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # <=0.0 means deactivate + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + return FlashAttnQKVPackedFP8Func.apply( + qkv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + torch.is_grad_enabled(), + ) + + +class FlashAttnVarlenQKVPackedFP8Func(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + is_grad_enabled, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None + ): + is_grad = is_grad_enabled and qkv.requires_grad + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach() + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + # figure out fwd parameters + if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" + q_fp8 = q + k_fp8 = k + v_fp8 = v + out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." 
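+ # Hedged aside (illustrative only, not the .utils implementation): per-tensor
+ # dynamic fp8 casting of this kind typically scales by fp8_max / amax, stores
+ # the tensor in float8, and keeps the inverse scale ("descale") so the kernel
+ # can recover the original magnitudes, e.g.:
+ #
+ #   def cast_to_fp8_sketch(x, fp8_dtype=torch.float8_e4m3fnuz):
+ #       fp8_max = torch.finfo(fp8_dtype).max
+ #       amax = x.abs().amax().to(torch.float32).clamp(min=1e-12)
+ #       scale, descale = fp8_max / amax, amax / fp8_max
+ #       return (x.to(torch.float32) * scale).to(fp8_dtype), descale.reshape(1)
+ #
+ # The actual cast_to_fp8 also receives the "thd" layout plus cu_seqlens/max_seqlen,
+ # which suggests it computes the scaling over the varlen token layout.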
+ q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None + + q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] + _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( + q_fp8, + k_fp8, + v_fp8, + out_fp8, + cu_seqlens, + cu_seqlens, + None, + None, + None, + alibi_slopes, + max_seqlen, + max_seqlen, + dropout_p, + softmax_scale, + False, + causal, + window_size[0], + window_size[1], + softcap, + return_softmax, + None, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_o=descale_o + ) + if is_grad: + ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) + ctx.dropout_p = dropout_p + ctx.max_seqlen = max_seqlen + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + out = out_fp8[..., :head_size_og] + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors + qkv_shape = q_fp8.shape[:-2] + (3, *q_fp8.shape[-2:]) + head_size_og = dout.size(2) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + + # figure out bwd parameters + if is_fp8(dout_padded): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_do is not None), f"You need to pass descale factors for do" + dout_padded_fp8 = dout_padded + dqkv, descale_dqkv = torch.zeros(qkv_shape, device=q_fp8.device), torch.zeros_like(descale_q) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." 
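+ # Note (illustrative): for the packed varlen layout qkv is
+ # [total_tokens, 3, nheads, head_dim], so dqkv[:, 0], dqkv[:, 1] and dqkv[:, 2]
+ # passed to varlen_bwd below are views into one contiguous gradient buffer and
+ # the kernel writes dq/dk/dv straight into it. A minimal shape check:
+ #
+ #   total, nheads, d = 32, 8, 64
+ #   dqkv = torch.zeros(total, 3, nheads, d)
+ #   assert dqkv[:, 0].shape == (total, nheads, d)      # dq view
+ #   assert dqkv[:, 0].data_ptr() == dqkv.data_ptr()    # same storage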
+ dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=ctx.max_seqlen) + dqkv, descale_dqkv = torch.zeros(qkv_shape, dtype=torch.float32, device=q_fp8.device), None + + # dq, dk, dv are allocated by us so they should already be contiguous + dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] + flash_attn_gpu.varlen_bwd( + dout_padded_fp8, + q_fp8, + k_fp8, + v_fp8, + out_fp8, + softmax_lse, + dqkv[:, 0], + dqkv[:, 1], + dqkv[:, 2], + cu_seqlens, + cu_seqlens, + ctx.alibi_slopes, + ctx.max_seqlen, + ctx.max_seqlen, + ctx.dropout_p, + ctx.softmax_scale, + False, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.deterministic, + None, + rng_state, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + None, + None, + None, + ) + dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension + return dqkv, None, None, None, None, None, None, None, None, None, None, None + + +def flash_attn_varlen_qkvpacked_fp8_func( + qkv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + return FlashAttnVarlenQKVPackedFP8Func.apply( + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + torch.is_grad_enabled(), + ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index b37308be491..3f2d92c22d6 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -1,16 +1,75 @@ import torch import triton import triton.language as tl -from .utils import _strides, get_padded_headsize - +from typing import Literal, Optional, Union +from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna + +def get_cdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + # Fall-back config. 
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + +def get_autotune_configs(): + if AUTOTUNE: + if is_cdna(): + autotune_configs, autotune_keys = get_cdna_autotune_configs() + fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys + reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys + return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) + else: + raise ValueError("Unknown Device Type") + else: + autotune_configs, autotune_keys = [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL", + "VARLEN", + "HQ", + "HK", + ] + + fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys + reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys + return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) + + +(fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) = get_autotune_configs() + +# @triton.autotune( +# configs=fwd_auto_tune_configs, +# key=fwd_autotune_keys, +# use_cuda_graph=True, +# ) @triton.jit def _fwd_kernel_splitK( Q, K, V, sm_scale, - Out_splitK, # [B, H, split_k, Mq, K] - Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out_splitK, # [B*H*G, split_k, Mq, K] + Metadata, # [B*H*G, 2, split_k, M_ceil] contains [mi, li] K_new, V_new, Cache_seqlens, @@ -70,62 +129,91 @@ def _fwd_kernel_splitK( IS_GQA: tl.constexpr, IS_CAUSAL: tl.constexpr, USE_ALIBI: tl.constexpr, + PADDED_HEAD: tl.constexpr, + GROUP_SIZE: tl.constexpr, ): - # Padding - PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) - if PADDED_HEAD: - d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL - - start_m = tl.program_id(0) - off_zhg = tl.program_id(1) - off_z = off_zhg // (H_q * G_q) - off_h_q = (off_zhg // G_q) % H_q - off_g_q = off_zhg % G_q - splitk_idx = tl.program_id(2) + # get program ids + pid_m = tl.program_id(0) + pid_zhg = tl.program_id(1) + pid_splitk = tl.program_id(2) - # pick batch index - if USE_CACHE_BATCH_IDX: - cache_batch_idx = tl.load(Cache_batch_idx + off_z) - else: - cache_batch_idx = off_z + # compute z, h and g ids + z_id = pid_zhg // (H_q * G_q) + hq_id = (pid_zhg // G_q) % H_q + g_id = pid_zhg % G_q - # Load ALiBi slope if enabled - if USE_ALIBI: - a_offset = off_z * stride_az + off_h_q * stride_ah - alibi_slope = tl.load(Alibi_slopes + a_offset) + # is gqa + if IS_GQA: + hk_id = hq_id // GROUP_SIZE + hv_id = hk_id else: - alibi_slope = None + hk_id = hq_id + hv_id = hq_id - lo = splitk_idx * BLOCK_N_PER_SPLIT + # figure out seqlens + lo = pid_splitk * BLOCK_N_PER_SPLIT if USE_CACHE_SEQLENs: - cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z) + cache_seqlen_last_idx = tl.load(Cache_seqlens + z_id) if NEW_KV: - kv_len = cache_seqlen_last_idx + N_CTX_NEW + N_CTX_K_FINAL = cache_seqlen_last_idx + N_CTX_NEW else: - kv_len = cache_seqlen_last_idx + N_CTX_K_FINAL = cache_seqlen_last_idx else: - kv_len = N_CTX_K - hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + N_CTX_K_FINAL = N_CTX_K + hi = tl.minimum((pid_splitk + 1) * BLOCK_N_PER_SPLIT, N_CTX_K_FINAL) - HEAD_RATIO: tl.constexpr = H_q // 
H_kv - if IS_GQA: - k_head_idx = off_h_q // HEAD_RATIO - v_head_idx = k_head_idx + # pick batch index + if USE_CACHE_BATCH_IDX: + cache_batch_idx = tl.load(Cache_batch_idx + z_id) + else: + cache_batch_idx = z_id + + # compute offsets + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + # compute ptrs + q_offset = Q + hq_id * stride_qh + z_id * stride_qz + g_id * stride_qg + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_offset = K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg + v_offset = V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg + + # compute masks + if PADDED_HEAD: + q_mask = (offs_m < N_CTX_Q)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] + kT_mask = (offs_d < ACTUAL_BLOCK_DMODEL)[:, None] & (offs_n < N_CTX_K_FINAL)[None, :] + v_mask = (offs_n < N_CTX_K_FINAL)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] + osk_mask = (offs_m < N_CTX_Q)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] else: - k_head_idx = off_h_q - v_head_idx = off_h_q + q_mask = (offs_m < N_CTX_Q)[:, None] + kT_mask = (offs_n < N_CTX_K_FINAL)[None, :] + v_mask = (offs_n < N_CTX_K_FINAL)[:, None] + osk_mask = (offs_m < N_CTX_Q)[:, None] - # calculate base offset - k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg - v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs, mask=q_mask, other=0.0) + q = (q * qk_scale).to(q.dtype) + + # load ALiBi slope if enabled + if USE_ALIBI: + a_offset = z_id * stride_az + hq_id * stride_ah + alibi_slope = tl.load(Alibi_slopes + a_offset) + else: + alibi_slope = None # Copy new Keys and Values into Cache if NEW_KV: - knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g + knew_base = K_new + hk_id * stride_kn_h + z_id * stride_kn_z + g_id * stride_kn_g # Determine the starting position for new data in the cache if USE_CACHE_SEQLENs: - start_idx = tl.load(Cache_seqlens + off_z) + start_idx = tl.load(Cache_seqlens + z_id) else: start_idx = N_CTX_K - N_CTX_NEW @@ -143,7 +231,7 @@ def _fwd_kernel_splitK( # Store to K tl.store( - k_base + + k_offset + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd + (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn, k_new_block, @@ -152,7 +240,7 @@ def _fwd_kernel_splitK( ) # Copy new Values - vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g + vnew_base = V_new + hv_id * stride_vn_h + z_id * stride_vn_z + g_id * stride_vn_g for i in range(0, N_CTX_NEW, BLOCK_N): # Load from V_new v_new_block = tl.load( @@ -166,7 +254,7 @@ def _fwd_kernel_splitK( # Store to V tl.store( - v_base + + v_offset + (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd, v_new_block, @@ -174,34 +262,6 @@ def _fwd_kernel_splitK( (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), ) - Q_block_ptr = tl.make_block_ptr( - base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg, - shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qd), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), 
- ) - - K_block_ptr = tl.make_block_ptr( - base=k_base, - shape=(ACTUAL_BLOCK_DMODEL, hi), - strides=(stride_kd, stride_kn), - offsets=(0, lo), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - V_block_ptr = tl.make_block_ptr( - base=v_base, - shape=(hi, ACTUAL_BLOCK_DMODEL), - strides=(stride_vn, stride_vd), - offsets=(lo, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - - K_scale_shift_block_ptr = None - V_scale_shift_block_ptr = None # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) @@ -209,45 +269,26 @@ def _fwd_kernel_splitK( acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821 - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - # load q: it will stay in SRAM throughout - q = tl.load( # noqa: F821 - tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, )) - q = (q * qk_scale).to(q.dtype) - if PADDED_HEAD: - q = tl.where(d_mask[None, :], q, 0.0) # loop over k, v and update accumulator for start_n in range(lo, hi, BLOCK_N): - k, v = load_k_v_group( - K_block_ptr, - V_block_ptr, - K_scale_shift_block_ptr, - V_scale_shift_block_ptr, - BOUNDS_CHECKS_N, - 1, - BLOCK_DMODEL, - ACTUAL_BLOCK_DMODEL, - Q.dtype.element_ty, - 0, - ) - if PADDED_HEAD: - k = tl.where(d_mask[:, None], k, 0.0) - v = tl.where(d_mask[None, :], v, 0.0) + kT_ptrs = k_offset + offs_d[:, None] * stride_kd + (start_n + offs_n)[None, :] * stride_kn + V_ptrs = v_offset + (start_n + offs_n)[:, None] * stride_vn + offs_d[None, :] * stride_vd + + # load k + kT = tl.load(kT_ptrs, mask=kT_mask, other=0.0) + v = tl.load(V_ptrs, mask=v_mask, other=0.0) # -- compute qk --- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) # noqa: F821 + qk += tl.dot(q, kT) # noqa: F821 if USE_ALIBI: - row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) col_idx = start_n + tl.arange(0, BLOCK_N) # Compute relative positions - relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :]) + relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) relative_pos = tl.abs(relative_pos) # Compute ALiBi bias @@ -256,11 +297,11 @@ def _fwd_kernel_splitK( # Apply causal mask if IS_CAUSAL is True if IS_CAUSAL: - row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) col_idx = start_n + tl.arange(0, BLOCK_N) # create a N_CTX_Q x kv_len causal mask - col_offset = N_CTX_Q - kv_len + col_offset = N_CTX_Q - N_CTX_K_FINAL causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) # Apply the mask @@ -293,101 +334,34 @@ def _fwd_kernel_splitK( # -- scale and update acc -- acc *= alpha[:, None] acc += tl.dot(p.to(v.dtype), v) - - # update pointers - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) # write back O - O_block_ptr = tl.make_block_ptr( - base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, - shape=(N_CTX_Q, BLOCK_DMODEL), - strides=(stride_osk_m, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) + osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_splitk * stride_osk_s + osk_ptrs = osk_offset + offs_m[:, None] * stride_osk_m + offs_d[None, :] * stride_osk_d tl.store( - tl.advance(O_block_ptr, (0, 0)), + osk_ptrs, acc, - boundary_check=(0, ), + mask=osk_mask, ) - # 
Write metadata for split-K reduction - Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + - tl.arange(0, BLOCK_M)) - tl.store(Metadata_ptr, m_i) - tl.store(Metadata_ptr + stride_m2, l_i) - - -@triton.jit -def load_k_v_group( - K_block_ptr, - V_block_ptr, - K_scale_shift_block_ptr, - V_scale_shift_block_ptr, - BOUNDS_CHECKS_N: tl.constexpr, - PACKED_PER_VAL: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - dtype: tl.constexpr, - group_id: tl.constexpr, -): - #Load K/V for a given block - - # Advance to the current quantization group - K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0)) - V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id)) - - # -- load k, v -- - k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()) - v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()) - - return k, v - - -@triton.jit -def cast_uint32_to_half2(scale_shift): - # Extract two float16 packed into one int32 - scale = scale_shift & 0xFFFF - shift = scale_shift >> 16 - scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) - shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) - return scale, shift - - -@triton.jit -def dequantize( - x_, - scale, - shift, - PACKED_PER_VAL: tl.constexpr = 8, -): - # PACKED_PER_VAL is the number of values packed into - # each element x_. For example, for int4 quantization - #and x_ of type int32, PACKED_PER_VAL is 8. - BLOCK_N: tl.constexpr = x_.shape[0] - BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] - offsets = tl.arange(0, PACKED_PER_VAL) * 4 - quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) - - quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) - # Trick - instead of converting int4 to float16 we view it as float16 - # and then multiply by 32768 * 512 == 2**24 - quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) - quant_offset = (quant_offset * 32768.0).to(tl.float16) - scale_512 = scale * 512 - - dequant = quant_offset * scale_512 + shift - return dequant + # write metadata for split-K reduction + metadata_offset = Metadata + pid_zhg * stride_mzhg + pid_splitk * stride_ms + metadata_ptr = metadata_offset + offs_m + tl.store(metadata_ptr, m_i) + tl.store(metadata_ptr + stride_m2, l_i) +# @triton.autotune( +# configs=reduce_auto_tune_configs, +# key=reduce_autotune_keys, +# use_cuda_graph=True, +# ) @triton.jit def _splitK_reduce( - Out_splitK, # [B, H, split_k, Mq, K] - Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] - Out, # [B, H, M, K] - LSE, # [B, H, M] + Out_splitK, # [B*H*G, split_k, Mq, K] + Metadata, # [B*H*G, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, G, M, K] + LSE, # [B*H*G, M] stride_osk_zhg, stride_osk_s, stride_osk_m, @@ -403,41 +377,50 @@ def _splitK_reduce( stride_ok, stride_lse_zhg, stride_lse_m, - M_ceil: tl.constexpr, - BLOCK_SIZE: tl.constexpr, + K_BLOCK_SIZE: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, H: tl.constexpr, G: tl.constexpr, split_k: tl.constexpr, splitK_pow2: tl.constexpr, - use_mask: tl.constexpr, + MASK_SPLITK: tl.constexpr, IS_CAUSAL: tl.constexpr, + PADDED_HEAD: tl.constexpr, ): - off_zhg = tl.program_id(0) - off_z = off_zhg // (H * G) - off_h = (off_zhg // G) % H - off_g = off_zhg % G - off_m = tl.program_id(1) - off_k = tl.program_id(2) + # get pids + pid_zhg = tl.program_id(0) + pid_m = 
tl.program_id(1) + pid_k = tl.program_id(2) - # read chunk - spk_idx = tl.arange(0, splitK_pow2) - kidx = tl.arange(0, BLOCK_SIZE) + # compute offsets + offs_splitK = tl.arange(0, splitK_pow2) + offs_k = pid_k * K_BLOCK_SIZE + tl.arange(0, K_BLOCK_SIZE) - Metadata_ptr = (Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm) - o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE + - stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k) + # compute masks + if PADDED_HEAD: + o_mask = offs_k < ACTUAL_BLOCK_DMODEL + else: + o_mask = None + + # compute ptrs + metadata_offset = Metadata + pid_zhg * stride_mzhg + metadata_ptr = metadata_offset + offs_splitK * stride_ms + pid_m * stride_mm + + osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_m * stride_osk_m + osk_ptr = osk_offset + offs_splitK[:, None] * stride_osk_s + offs_k[None, :] * stride_osk_k # read max values of each splitK - if use_mask: - spk_mask = spk_idx < split_k - l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) - l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) - acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0) + if MASK_SPLITK: + splitK_mask = offs_splitK < split_k + l_m = tl.load(metadata_ptr, mask=splitK_mask, other=float("-inf")) + l_sum = tl.load(metadata_ptr + stride_m2, mask=splitK_mask, other=0.0) + acc = tl.load(osk_ptr, mask=splitK_mask[:, None], other=0.0) else: - l_m = tl.load(Metadata_ptr) - l_sum = tl.load(Metadata_ptr + stride_m2) - acc = tl.load(o_ptr) + l_m = tl.load(metadata_ptr) + l_sum = tl.load(metadata_ptr + stride_m2) + acc = tl.load(osk_ptr) g_m = tl.max(l_m, axis=0) @@ -460,12 +443,15 @@ def _splitK_reduce( acc_out = tl.sum(acc, axis=0) / g_sum # Store output - Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m + - off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) - tl.store(Out_ptr, acc_out) + z_id = pid_zhg // (H * G) + h_id = (pid_zhg // G) % H + g_id = pid_zhg % G + out_offset = Out + z_id * stride_oz + h_id * stride_oh + g_id * stride_og + out_ptr = out_offset + pid_m * stride_om + offs_k + tl.store(out_ptr, acc_out, mask=o_mask) # Store lse - l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + l_ptrs = LSE + pid_zhg * stride_lse_zhg + pid_m if IS_CAUSAL: lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m) tl.store(l_ptrs, lse) @@ -473,6 +459,41 @@ def _splitK_reduce( tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. 
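+ # Illustrative scalar check of the bit trick below (plain Python, not Triton):
+ # a 4-bit value n placed in the low mantissa bits of a float16 is the subnormal
+ # n * 2**-24, so multiplying by 32768 and then by 512 (folded into the scale)
+ # recovers n exactly:
+ #
+ #   import struct
+ #   n = 11
+ #   as_fp16 = struct.unpack('<e', struct.pack('<H', n))[0]   # n * 2**-24
+ #   assert abs(as_fp16 * 32768.0 * 512.0 - n) < 1e-3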
+ + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: # Scale and shift are such that quantization linearly maps # int4 values range [0..15] to input values range min(k)..max(k) @@ -540,122 +561,204 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int: split_k = max(split_k, 1) return split_k -def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new): - # kernel config +def attention_decode_forward_triton_impl( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + out: torch.Tensor, + sm_scale: float, + causal: bool, + alibi_slopes: Optional[torch.Tensor], + layout: Literal["bshd"], + cache_seqlens: Optional[Union[(int, torch.Tensor)]], + cache_batch_idx: Optional[torch.Tensor], +): + # triton configs BLOCK_M = 16 BLOCK_N = 64 + num_stages = 1 + num_warps_fwd = 1 + num_warps_reduce = 4 + + # kernel_configs + is_new_kv = True if k_new is not None and v_new is not None else False + use_alibi = False if alibi_slopes is None else True + use_cache_seqlens = cache_seqlens is not None SPLIT_K = None NUM_QUANT_GROUPS = 1 - # kernels expects "bsghd" - original_layout = layout + # get shapes and strides + (batch_size, seqlen_q, nheads_q, dim_q), (stride_qz, stride_qh, stride_qm, stride_qd) = get_shape_and_strides_from_layout(q, layout) + (_, seqlen_kc, nheads_kc, dim_kc), (stride_kc_z, stride_kc_h, stride_kc_n, stride_kc_d) = get_shape_and_strides_from_layout(k_cache, layout) + (_, seqlen_vc, nheads_vc, dim_vc), (stride_vc_z, stride_vc_h, stride_vc_n, stride_vc_d) = get_shape_and_strides_from_layout(v_cache, layout) + if is_new_kv: + ( _, seqlen_kn, nheads_kn, dim_kn), (stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d) = get_shape_and_strides_from_layout(k_new, layout) + (_, seqlen_vn, nheads_vn, dim_vn), (stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d) = get_shape_and_strides_from_layout(v_new, layout) + else: + ( _, seqlen_kn, nheads_kn, dim_kn), (stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d) = (None, None, None, None), (None, None, None, None) + (_, seqlen_vn, nheads_vn, dim_vn), (stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d) = (None, None, None, None), (None, None, None, None) + (_, seqlen_o, nheads_o, dim_o), (stride_oz, stride_oh, stride_om, stride_od) = get_shape_and_strides_from_layout(out, layout) + if use_alibi: + stride_az, stride_ah = alibi_slopes.stride() + else: + stride_az, stride_ah = (None, None) + + assert dim_q == dim_kc == dim_vc, f"Dimensions must match: {dim_q}, {dim_kc}, {dim_vc}" + + # add extra information needed by the kernels if layout == "bshd": - q=q.unsqueeze(2) - k=k.unsqueeze(2) - v=v.unsqueeze(2) - if new_kv: - k_new = k_new.unsqueeze(2) - v_new = v_new.unsqueeze(2) - layout = "bsghd" - 
elif layout == "bhsd": - q=q.permute(0, 2, 1, 3).unsqueeze(2) - k=k.permute(0, 2, 1, 3).unsqueeze(2) - v=v.permute(0, 2, 1, 3).unsqueeze(2) - if new_kv: - k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2) - v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2) - layout = "bsghd" - elif layout == "bsghd": - pass - elif layout is None: - raise ValueError("Layout not given") - assert layout == "bsghd" - - # get dims - batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape - _, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape - _, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape - - assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}" + (n_group_q, heads_per_group_q), stride_qg = (1, nheads_q), stride_qm + (n_group_k, heads_per_group_k), stride_kc_g = (1, nheads_kc), stride_kc_n + (n_group_v, heads_per_group_v), stride_vc_g = (1, nheads_vc), stride_vc_n + if is_new_kv: + (n_group_kn, heads_per_group_kn), stride_kn_g = (1, nheads_kn), stride_kn_n + (n_group_vn, heads_per_group_vn), stride_vn_g = (1, nheads_vn), stride_vn_n + else: + (n_group_kn, heads_per_group_kn), stride_kn_g = (None, None), None + (n_group_vn, heads_per_group_vn), stride_vn_g = (None, None), None + (n_group_o, heads_per_group_o), stride_og = (1, nheads_o), stride_om + else: + raise ValueError(f"{layout} layout is not supported") # get padded size - dim_padded = get_padded_headsize(dim_k) + dim_padded = get_padded_headsize(dim_kc) + is_padded_head = dim_padded != dim_kc # Handle MQA/GQA case - if heads_per_group_q > heads_per_group_k: + group_size = nheads_q // nheads_kc + if group_size > 1: is_gqa = True - elif heads_per_group_q < heads_per_group_k: - raise ValueError("heads_per_group_q < heads_per_group_k") else: is_gqa = False - assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}" - if SPLIT_K is not None: split_k = SPLIT_K else: # Use heuristics - split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_k) # NOTE: should the split think about seqlens? + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) # NOTE: should the split think about seqlens? 
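+ # Hedged PyTorch sketch of what the split-K merge in _splitK_reduce computes
+ # (simplified, base-2 domain as in the kernels): each split keeps a running
+ # max m_s, a running sum l_s and an unnormalized accumulator acc_s, and the
+ # splits are combined with an online-softmax style rescale:
+ #
+ #   def merge_splits(m, l, acc):                 # m, l: [S, M]; acc: [S, M, D]
+ #       g_m = m.max(dim=0).values                # [M]
+ #       alpha = torch.exp2(m - g_m[None, :])     # [S, M]
+ #       g_sum = (l * alpha).sum(dim=0)           # [M]
+ #       out = (acc * alpha[..., None]).sum(dim=0) / g_sum[..., None]
+ #       return out, (g_m + torch.log2(g_sum)) / 1.44269504   # output, lse (natural log)
+ #
+ # The real reduce kernel additionally guards g_sum == 0 for fully masked
+ # causal rows.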
+ split_size = (seqlen_kc + split_k - 1) // split_k + # setup grid seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M - out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device) + grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch_size * n_group_q * heads_per_group_q, split_k) + + # create intermediate tensors + out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_kc], dtype=torch.float32, device=q.device) metadata = torch.empty([batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device) - lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), device=q.device, dtype=torch.float32) - grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * n_group_q * heads_per_group_q, split_k) - - num_warps = 1 - split_size = (seqlen_k + split_k - 1) // split_k - use_cache_seqlens = cache_seqlens is not None + lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), dtype=torch.float32, device=q.device) + + # get intermediate tensor strides + stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d = out_splitk.stride() + stride_mzhg, stride_m2, stride_ms, stride_mm = metadata.stride() + stride_lse_zhg, stride_lse_m = lse.stride() + + if False: + print("batch_size, seqlen_q, nheads_q, dim_q", (batch_size, seqlen_q, nheads_q, dim_q)) + print("_, seqlen_kc, nheads_kc, dim_kc", (_, seqlen_kc, nheads_kc, dim_kc)) + print("dim_padded:", dim_padded) + print("stride_qz, stride_qm, stride_qg, stride_qh, stride_qd", (stride_qz, stride_qm, stride_qg, stride_qh, stride_qd)) + print("stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d", (stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d)) + print("stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d", (stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d)) + if is_new_kv: + print("stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d", (stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d)) + print("stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d", (stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d)) + print("stride_oz, stride_om, stride_og, stride_oh, stride_od", (stride_oz, stride_om, stride_og, stride_oh, stride_od)) + print("stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d", (stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d)) + print("stride_mzhg, stride_m2, stride_ms, stride_mm", (stride_mzhg, stride_m2, stride_ms, stride_mm)) + print("stride_lse_zhg, stride_lse_m", (stride_lse_zhg, stride_lse_m)) # TODO: enable quantization _fwd_kernel_splitK[grid]( Q=q, - K=k, - V=v, + K=k_cache, + V=v_cache, sm_scale=sm_scale, Out_splitK=out_splitk, Metadata=metadata, - K_new = k_new, - V_new = v_new, + K_new=k_new, + V_new=v_new, Cache_seqlens=cache_seqlens, Cache_batch_idx=cache_batch_idx, Alibi_slopes=alibi_slopes, - **_strides(q, "qz", "qm", "qg", "qh", "qd"), - **_strides(k, "kz", "kn", "kg", "kh", "kd"), - **_strides(v, "vz", "vn", "vg", "vh", "vd"), - **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"), - **_strides(metadata, "mzhg", "m2", "ms", "mm"), - **_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"), - **_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"), - **_strides(alibi_slopes, "az", "ah"), + # q strides + stride_qz=stride_qz, + stride_qm=stride_qm, + stride_qg=stride_qg, + 
stride_qh=stride_qh, + stride_qd=stride_qd, + # k strides + stride_kz=stride_kc_z, + stride_kn=stride_kc_n, + stride_kg=stride_kc_g, + stride_kh=stride_kc_h, + stride_kd=stride_kc_d, + # v strides + stride_vz=stride_vc_z, + stride_vn=stride_vc_n, + stride_vg=stride_vc_g, + stride_vh=stride_vc_h, + stride_vd=stride_vc_d, + # out_splitk strides + stride_osk_zhg=stride_osk_zhg, + stride_osk_s=stride_osk_s, + stride_osk_m=stride_osk_m, + stride_osk_d=stride_osk_d, + # metadata strides + stride_mzhg=stride_mzhg, + stride_m2=stride_m2, + stride_ms=stride_ms, + stride_mm=stride_mm, + # k_new strides + stride_kn_z=stride_kn_z, + stride_kn_n=stride_kn_n, + stride_kn_g=stride_kn_g, + stride_kn_h=stride_kn_h, + stride_kn_d=stride_kn_d, + # v_new strides + stride_vn_z=stride_vn_z, + stride_vn_n=stride_vn_n, + stride_vn_g=stride_vn_g, + stride_vn_h=stride_vn_h, + stride_vn_d=stride_vn_d, + # alibi strides + stride_az=stride_az, + stride_ah=stride_ah, Z=batch_size, H_q=heads_per_group_q, H_kv=heads_per_group_k, G_q=n_group_q, N_CTX_Q=seqlen_q, - N_CTX_K=seqlen_k, - N_CTX_NEW=k_new.shape[1] if new_kv else None, + N_CTX_K=seqlen_kc, + N_CTX_NEW=seqlen_kn, BLOCK_N_PER_SPLIT=split_size, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=dim_padded, - ACTUAL_BLOCK_DMODEL=dim_k, + ACTUAL_BLOCK_DMODEL=dim_kc, BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, USE_CACHE_SEQLENs=use_cache_seqlens, USE_CACHE_BATCH_IDX=cache_batch_idx is not None, - NEW_KV=new_kv, + NEW_KV=is_new_kv, IS_GQA=is_gqa, IS_CAUSAL=causal, - USE_ALIBI=False if alibi_slopes is None else True, - num_warps=num_warps, - num_stages=1, + USE_ALIBI=use_alibi, + PADDED_HEAD=is_padded_head, + GROUP_SIZE=group_size, + num_warps=num_warps_fwd, + num_stages=num_stages, ) - out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype) + if DEBUG: + print("Out_splitK:", out_splitk, out_splitk.shape) + print("metadata:", metadata, metadata.shape) + print("lse:", lse, lse.shape) + print("Out:", out, out.shape) # Merge together splitK_pow2 = triton.next_power_of_2(split_k) - use_mask = splitK_pow2 > split_k + mask_split_k = splitK_pow2 > split_k if batch_size * n_group_q * heads_per_group_q * seqlen_q >= 512: k_block_num = 1 else: @@ -664,40 +767,48 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes k_block_size = dim_padded // k_block_num grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num) + + if DEBUG: + print("splitK_pow2:", splitK_pow2) + print("k_block_num:", k_block_num) + print("k_block_size:", k_block_size) + print("grid:", grid) + _splitK_reduce[grid]( out_splitk, metadata, out, lse, - **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), - **_strides(metadata, "mzhg", "m2", "ms", "mm"), - **_strides(out, "oz", "om", "og", "oh", "ok"), - **_strides(lse, "lse_zhg", "lse_m"), - M_ceil=seqlen_q_ceil, - BLOCK_SIZE=k_block_size, + # Split-K output strides + stride_osk_zhg=stride_osk_zhg, + stride_osk_s=stride_osk_s, + stride_osk_m=stride_osk_m, + stride_osk_k=stride_osk_d, + # Metadata strides + stride_mzhg=stride_mzhg, + stride_m2=stride_m2, + stride_ms=stride_ms, + stride_mm=stride_mm, + # Output tensor strides + stride_oz=stride_oz, + stride_oh=stride_oh, + stride_og=stride_og, + stride_om=stride_om, + stride_ok=stride_od, + # LSE strides + stride_lse_zhg=stride_lse_zhg, + stride_lse_m=stride_lse_m, + K_BLOCK_SIZE=k_block_size, + BLOCK_DMODEL=dim_padded, + ACTUAL_BLOCK_DMODEL=dim_kc, G=n_group_q, H=heads_per_group_q, # 
TODO: Tune num_warps split_k=split_k, splitK_pow2=splitK_pow2, - use_mask=use_mask, + MASK_SPLITK=mask_split_k, IS_CAUSAL=causal, - num_warps=4) - - lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q]) - if q.ndim == 4: - # BMGHK -> BMHK - assert n_group_q == 1 - out = out[:, :, 0] - lse = lse[:, 0] - if seqlen_k == 0: - out.zero_() - out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous() - - # output is batch_size, heads_per_group_q * group_q, seqlen_q, dim_q - if original_layout == "bshd": - # out=out.transpose(1, 2).contiguous() # this screws up heads and data. - # the data is laid out properly. Just need to reshape dims - out = out.reshape(batch_size, seqlen_q, -1, dim_padded) - - return out.narrow(-1, 0, dim_k), lse + PADDED_HEAD=is_padded_head, + num_warps=num_warps_reduce) + + return lse diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 2a59dc4e5d2..6f69cd02813 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -1,32 +1,12 @@ import torch import triton import triton.language as tl -from .utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, DEBUG, AUTOTUNE - -@triton.jit -def cdiv_fn(x, y): - return (x + y - 1) // y - -@triton.jit -def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): - ms = tl.arange(0, m) - ns = tl.arange(0, n) - return philox_offset + ms[:, None] * stride + ns[None, :] - - -@triton.jit -def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) - # TODO: use tl.randint for better performance - return tl.rand(philox_seed, rng_offsets) - - -@triton.jit -def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) - rng_keep = rng_output > dropout_p - return rng_keep +from typing import Literal, Optional, Union +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_shapes_from_layout, get_strides_from_layout, is_cdna, is_fp8, is_rdna, create_dropout_mask +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) +tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) # Convenience function to load with optional boundary checks. # "First" is the major dim, "second" is the minor dim. @@ -46,49 +26,16 @@ def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): tensor = tl.load(ptrs) return tensor - -@triton.jit -def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): - # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix - # for casual mask we want something like this where (1 is kept and 0 is masked) - # seqlen_q = 2 and seqlen_k = 5 - # 1 1 1 1 0 - # 1 1 1 1 1 - # seqlen_q = 5 and seqlen_k = 2 - # 0 0 - # 0 0 - # 0 0 - # 1 0 - # 1 1 - # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal - # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False - # 1. 
offs_m[:,None] = [[0], - # [1], - # 2. offs_m[:,None] + seqlen_k = [[5], - # [6], - # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], - # [4], - # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], - # [4], [ 4, 3, 2, 1, 0]] - # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], - # [ -4, -3, -2, -1, 0]], - relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - if transpose: - return alibi_block.T - else: - return alibi_block - - @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, - actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, exp_scores_ptrs, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, +def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, + actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + descale_q, descale_k, descale_v, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_EXP2: tl.constexpr, - RETURN_SCORES: tl.constexpr): + ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + RETURN_SCORES: tl.constexpr, ACCUMULATOR_TYPE): if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 @@ -105,7 +52,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri if PRE_LOAD_V: # We can use the same offsets as k, just with dims transposed. v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. 
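The compute_alibi_block helper removed from this file (now imported from utils) is documented by the worked example in the comments above. A plain-PyTorch restatement of that relative-position bias, kept here for reference (shapes are hypothetical):

```python
import torch

def alibi_bias_sketch(alibi_slope: float, seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    offs_m = torch.arange(seqlen_q)[:, None]
    offs_n = torch.arange(seqlen_k)[None, :]
    # diagonal sticks to the bottom right; zero penalty on the diagonal, growing with distance
    relative_pos = offs_m + seqlen_k - seqlen_q - offs_n
    return -alibi_slope * relative_pos.abs()

# alibi_bias_sketch(1.0, 2, 5) reproduces the commented example:
# [[-3, -2, -1,  0, -1],
#  [-4, -3, -2, -1,  0]]
```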
@@ -120,13 +67,18 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) - + + # compute masks + q_mask = (OFFS_M[:, None] < actual_seqlen_q) + k_mask = ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) + p_mask = q_mask & k_mask + # -- compute qk ---- - qk += tl.dot(q, k) + if IS_FP8 : + qk += (tl.dot(q, k) * descale_q * descale_k) + else: + qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE - if RETURN_SCORES: - score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) - tl.store(score_ptrs, qk_scaled, mask=score_mask) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal @@ -137,8 +89,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) qk_scaled += bias - if alibi_slope is not None: - # Compute the global position of each token within the sequence + if USE_ALIBI: + # compute the global position of each token within the sequence global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) global_n_positions = start_n + tl.arange(0, BLOCK_N) alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, @@ -149,10 +101,6 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # scale and subtract max q_shifted = qk_scaled - m_ij[:, None] - if RETURN_SCORES: - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - scores_scaled_shifted_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) - tl.store(scores_scaled_shifted_ptrs, q_shifted, mask=scores_scaled_shifted_mask) # Compute scaled QK and softmax probabilities if USE_EXP2: @@ -163,17 +111,23 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N - keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) - if RETURN_SCORES: - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) - tl.store(exp_scores_ptrs, tl.where(keep, p, -p), mask=exp_score_mask) - p = tl.where(keep, p, 0.0) + if tl_DROPOUT_USE_PYTORCH: + dropout_mask = tl.load(dropout_mask_ptrs, mask=p_mask) + else: + rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance + dropout_mask = rng_output > dropout_p + if tl_DROPOUT_DUMP: + tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask) + + # return scores with negative values for dropped vals + sd_mask = tl.where(dropout_mask, p, -p) + tl.store(sd_mask_ptrs, sd_mask, mask=p_mask) + + # apply dropout mask in place + p = tl.where(dropout_mask, p, 0.0) elif RETURN_SCORES: # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. 
We are not doing that - exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) - tl.store(exp_scores_ptrs, p, mask=exp_score_mask) + tl.store(sd_mask_ptrs, p, mask=p_mask) # -- update output accumulator -- # alpha is an adjustment factor for acc and li as we loop and find new maxes @@ -190,15 +144,23 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(v.type.element_ty), v) + + if IS_FP8: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) + else: + acc += tl.dot(p.to(v.type.element_ty), v) + k_ptrs += BLOCK_N * stride_kn v_ptrs += BLOCK_N * stride_vk if bias_ptrs is not None: bias_ptrs += BLOCK_N * stride_bn if RETURN_SCORES: - score_ptrs += BLOCK_N - scores_scaled_shifted_ptrs += BLOCK_N - exp_scores_ptrs += BLOCK_N + sd_mask_ptrs += BLOCK_N * stride_sn + + if ENABLE_DROPOUT: + dropout_mask_ptrs += BLOCK_N * stride_sn + philox_ptrs += BLOCK_N * stride_sn return acc, l_i, m_i @@ -219,7 +181,7 @@ def get_cdna_autotune_configs(): # Fall-back config. triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] def get_rdna_autotune_configs(): @@ -239,7 +201,7 @@ def get_rdna_autotune_configs(): # Fall-back config. triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] def get_autotune_configs(): @@ -263,7 +225,7 @@ def get_autotune_configs(): "MAX_SEQLENS_Q", "MAX_SEQLENS_K", "ACTUAL_BLOCK_DMODEL", - "VARLEN", + "IS_VARLEN", "HQ", "HK", ] @@ -277,34 +239,46 @@ def get_autotune_configs(): use_cuda_graph=True, ) @triton.jit -def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, +def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, + Descale_Q, Descale_K, Descale_V, Descale_O, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, + SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, scores, scores_scaled_shifted, exp_scores, alibi_slopes, HQ: tl.constexpr, + dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_INFERENCE: tl.constexpr, IS_CAUSAL: 
tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr): + ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr): + # set params + ACCUMULATOR_TYPE = tl.float32 + + # compute offsets start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) - if VARLEN: + + # handle seqlen + if IS_VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - # print("cu_seqlens_q_start:", cu_seqlens_q_start) - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. + + # we have a one-size-fits-all grid in id(0). Some seqlens might be too small for all start_m so for those we return early. if start_m * BLOCK_M > seqlen_q: return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + elif IS_INFERENCE: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = tl.load(Cache_seqlens + off_z) else: cu_seqlens_q_start = 0 cu_seqlens_k_start = 0 @@ -317,14 +291,14 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ # inf written to LSE. We don't need to do any GEMMs in this case. # This block of code determines what N is, and if this WG is operating # on those M rows. - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + n_blocks = tl.cdiv(seqlen_k, BLOCK_N) if (IS_CAUSAL): # If seqlen_q == seqlen_k, the attn scores are a square matrix. # If seqlen_q != seqlen_k, attn scores are rectangular which means # the causal mask boundary is bottom right aligned, and ends at either # the top edge (seqlen_q < seqlen_k) or left edge. # This captures the decrease in n_blocks if we have a rectangular attn matrix - n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + n_blocks_seqlen = tl.cdiv((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) @@ -341,9 +315,9 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ # statically known. 
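The causal bound on n_blocks above follows from the bottom-right aligned mask: query row i may attend keys 0..i + seqlen_k - seqlen_q, so query block start_m needs at most cdiv((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) key blocks. A quick sanity check with arbitrary numbers:

```python
def cdiv(x, y):
    return (x + y - 1) // y

BLOCK_M, BLOCK_N = 128, 64
seqlen_q, seqlen_k, start_m = 200, 500, 1

last_row = min((start_m + 1) * BLOCK_M, seqlen_q) - 1
last_key = last_row + seqlen_k - seqlen_q                  # last visible key for this query block
n_blocks_needed = cdiv(last_key + 1, BLOCK_N)
n_blocks = min(cdiv(seqlen_k, BLOCK_N),
               cdiv((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N))
assert n_blocks_needed <= n_blocks                         # the bound never skips a needed block
```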
l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m l_ptrs = l_offset + offs_m * stride_lse_m - - l = tl.full([BLOCK_M], value=0.0, dtype=tl.float32) - + + l = tl.full([BLOCK_M], value=0.0, dtype=ACCUMULATOR_TYPE) + # mask_m_offsets = start_m + tl.arange(0, BLOCK_M) # lse_mask = mask_m_offsets < causal_start_idx # softmax_lse = tl.where(lse_mask, 0.0, softmax_lse) @@ -391,34 +365,37 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ alibi_slope = None if RETURN_SCORES: - scores_offset = scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm - score_ptrs = scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn - - scores_scaled_shifted_offset = scores_scaled_shifted + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm - scores_scaled_shifted_ptrs = scores_scaled_shifted_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn - - exp_scores_offset = exp_scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm - exp_scores_ptrs = exp_scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + sd_mask_offset = sd_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm + sd_mask_ptrs = sd_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn else: - score_ptrs = None - scores_scaled_shifted_ptrs = None - exp_scores_ptrs = None + sd_mask_ptrs = None if ENABLE_DROPOUT: - off_hz = off_z * HQ + off_h_q - batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + dropout_mask_offset = dropout_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm + dropout_mask_ptrs = dropout_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + batch_philox_offset = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm + philox_ptrs = batch_philox_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn else: - batch_philox_offset = 0 + dropout_mask_ptrs = None + philox_ptrs = 0 # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + m_i = tl.full([BLOCK_M], float("-inf"), dtype=ACCUMULATOR_TYPE) + l_i = tl.full([BLOCK_M], 1.0, dtype=ACCUMULATOR_TYPE) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=ACCUMULATOR_TYPE) # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q if PADDED_HEAD: q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + # Load scale factors if IS_FP8. + if IS_FP8: + descale_q = tl.load(Descale_Q + off_z * stride_descale_q_z + off_h_q) + descale_k = tl.load(Descale_K + off_z * stride_descale_k_z + off_h_k) + descale_v = tl.load(Descale_V + off_z * stride_descale_v_z + off_h_k) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) @@ -439,16 +416,17 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ # value because there is no masking. Similarly we do not need padding. 
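The descale_q/descale_k/descale_v values loaded above undo whatever per-(batch, head) scale was applied when the inputs were quantized to fp8. A hedged host-side sketch of one plausible recipe; the fp8 dtype, the per-(batch, head) granularity, and the helper name below are assumptions for illustration, not the repo's compute_fp8_scaling_factors:

```python
import torch

def fp8_scale_sketch(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # x: [batch, seqlen, heads, dim]; one scale per (batch, head) slice -- an assumption
    fp8_max = torch.finfo(fp8_dtype).max
    amax = x.abs().amax(dim=(1, 3)).clamp(min=1e-12)        # [batch, heads]
    scale = fp8_max / amax                                   # maps each slice into fp8 range
    descale = 1.0 / scale                                    # what the kernel multiplies back in
    x_fp8 = (x * scale[:, None, :, None]).to(fp8_dtype)
    return x_fp8, descale
```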
if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, - exp_scores_ptrs, + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, + sd_mask_ptrs, dropout_mask_ptrs, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, block_max, 0, 0, 0, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, + block_min, block_max, 0, 0, 0, alibi_slope, + descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, # IS_CAUSAL, .... False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES) + ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE) block_min = block_max block_max = n_blocks * BLOCK_N @@ -464,23 +442,25 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ if USE_BIAS: bias_ptrs += n_full_blocks * BLOCK_N * stride_bn if RETURN_SCORES: - score_ptrs += n_full_blocks * BLOCK_N - scores_scaled_shifted_ptrs += n_full_blocks * BLOCK_N - exp_scores_ptrs += n_full_blocks * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, - exp_scores_ptrs, block_min, block_max, offs_n_causal, masked_blocks, - n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, + sd_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn + if ENABLE_DROPOUT: + dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn + philox_ptrs += n_full_blocks * BLOCK_N * stride_sn + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, + sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks, + n_extra_tokens, alibi_slope, descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES) + ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE) # epilogue # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. l_recip = 1 / l_i[:, None] acc = acc * l_recip if ENABLE_DROPOUT: - acc = acc / (1 - dropout_p) + dropout_scale = 1 / (1 - dropout_p) + acc = acc * dropout_scale # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, # then we have one block with a row of all NaNs which come from computing # softmax over a row of all -infs (-inf - inf = NaN). 
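Taken together, the alpha rescaling inside _attn_fwd_inner and the 1 / l_i epilogue above implement the standard online-softmax accumulation. A self-contained PyTorch sketch of the same idea (no masking, dropout, bias, or sm_scale; purely illustrative):

```python
import torch

def blockwise_attention_sketch(q, k, v, block_n=4):
    # q: [M, D], k: [N, D], v: [N, D]
    M, D = q.shape
    m_i = torch.full((M,), float("-inf"))
    l_i = torch.zeros(M)
    acc = torch.zeros(M, D)
    for start in range(0, k.shape[0], block_n):
        s = q @ k[start:start + block_n].T                  # scores for this KV block
        m_ij = torch.maximum(m_i, s.max(dim=-1).values)     # new running max
        p = torch.exp(s - m_ij[:, None])
        alpha = torch.exp(m_i - m_ij)                       # rescales the old acc and l_i
        acc = acc * alpha[:, None] + p @ v[start:start + block_n]
        l_i = l_i * alpha + p.sum(dim=-1)
        m_i = m_ij
    return acc / l_i[:, None]                               # epilogue: multiply by 1 / l_i

q, k, v = torch.randn(3, 8), torch.randn(10, 8), torch.randn(10, 8)
ref = torch.softmax(q @ k.T, dim=-1) @ v
assert torch.allclose(blockwise_attention_sketch(q, k, v), ref, atol=1e-5)
```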
We check for that here @@ -488,7 +468,6 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ end_m_idx = (start_m + 1) * BLOCK_M start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k - acc = acc.to(Out.type.element_ty) if IS_CAUSAL: if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) @@ -496,7 +475,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - + # write back LSE(Log Sum Exponents), the log of the normalization constant l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m l_ptrs = l_offset + offs_m * stride_lse_m @@ -510,7 +489,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ softmax_lse *= LN2 else: softmax_lse = m_i + tl.math.log(l_i) - + if IS_CAUSAL: # zero out nans caused by -infs when doing causal lse_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx @@ -534,55 +513,83 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) if PADDED_HEAD: o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) - tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + + if FP8_OUTPUT: + scale_acc, descale_acc = compute_fp8_scaling_factors(acc, FP8_MAX) + tl.store(Descale_O + off_z * stride_descale_o_z + off_h_q, descale_acc) + tl.store(o_ptrs, (acc * scale_acc).to(Out.type.element_ty), mask=o_ptrs_mask) + else: + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) def attention_prefill_forward_triton_impl( - q, - k, - v, - o, - sm_scale, - alibi_slopes, - causal, - bias, - dropout_p, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlens_q, - max_seqlens_k, - return_scores, - use_exp2): + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + bias: Optional[torch.Tensor], + layout: Literal["bshd", "bhsd", "thd"], + # varlen + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlens_q: int, + max_seqlens_k: int, + # inference + cache_seqlens: Optional[Union[(int, torch.Tensor)]], + cache_batch_idx: Optional[torch.Tensor], + # dropout + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + # misc + return_softmax: bool, + use_exp2: bool, + # fp8 + descale_q: Optional[torch.Tensor], + descale_k: Optional[torch.Tensor], + descale_v: Optional[torch.Tensor], + descale_o: Optional[torch.Tensor], +): + IS_FP8 = is_fp8(q) + if IS_FP8: + FP8_MAX: tl.constexpr = torch.finfo(q.dtype).max + + assert is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8." + + if is_fp8(o): + FP8_OUTPUT = True + assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor for the output." 
+ else: + FP8_OUTPUT = False - if DEBUG: - print() - print("attention_prefill_forward_triton_impl") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("o:", o, o.shape) - print("sm_scale:", sm_scale) - print("alibi_slopes:", alibi_slopes) - print("causal:", causal) - print("bias:", bias) - print("dropout_p:", dropout_p) - print("layout:", layout) - print("cu_seqlens_q:", cu_seqlens_q) - print("cu_seqlens_k:", cu_seqlens_k) - print("max_seqlens_q:", max_seqlens_q) - print("max_seqlens_k:", max_seqlens_k) - print("return_scores:", return_scores) - print("use_exp2:", use_exp2) - - # check if varlen + # Get strides for the kernel + stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None + stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None + stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None + stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None + else: + FP8_MAX = None + FP8_OUTPUT = False + descale_q = descale_k = descale_v = descale_o = None + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = None + + # check flags is_varlen = layout == "thd" + use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) + is_inference = False if cache_seqlens is None else True + if is_inference: + assert layout == "bshd", f"{layout} layout is not supported with inference. Use bshd layout" + if DEBUG: + print(f"is_inference:", is_inference) # NOTE: a large bias tensor leads to overflow during pointer arithmetic if (bias is not None): assert (bias.numel() < 2**31) - batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k) + batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shapes_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k) q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) # Get closest power of 2 over or equal to 32. @@ -593,60 +600,50 @@ def attention_prefill_forward_triton_impl( grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) - if return_scores: - scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, - dtype=torch.float32) - scores_scaled_shifted = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, - dtype=torch.float32) - scores_strides = (scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3)) - else: - scores = None - scores_scaled_shifted = None - scores_strides = (0, 0 , 0 , 0) - - # exp_scores is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out + # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing - # only. This return holds no useful output aside from debugging. - if return_scores: - exp_scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, + # only. This return holds no useful output aside from debugging. 
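The sign-bit convention the comment above describes is easy to restate in plain PyTorch. A tiny sketch; the keep mask here is a stand-in for the philox-based RNG inside the kernel and the shapes are arbitrary:

```python
import torch

p = torch.rand(4, 4)                          # a block of softmax values (stand-in)
dropout_p = 0.25
keep = torch.rand_like(p) > dropout_p         # stand-in for the kernel's philox keep mask
sd_mask = torch.where(keep, p, -p)            # kept entries stay positive, dropped ones go negative
p = torch.where(keep, p, torch.zeros_like(p)) / (1 - dropout_p)   # rescale keeps the expectation unchanged
```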
+ use_dropout = (dropout_p > 0.0) + if use_dropout or return_softmax: + sd_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, + dtype=torch.float32) + if DROPOUT_USE_PYTORCH: + dropout_mask = create_dropout_mask(dropout_p, (batch, nheads_q, max_seqlens_q, max_seqlens_k), seed = philox_seed) + else: + dropout_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, dtype=torch.float32) + scores_strides = (sd_mask.stride(0), sd_mask.stride(1), sd_mask.stride(2), sd_mask.stride(3)) else: - exp_scores = None + sd_mask = None + dropout_mask = None + scores_strides = (0, 0, 0, 0) # stores LSE the log of the normalization constant / sum of expoential score(unnormalzied probablities) if is_varlen: - softmax_lse = torch.empty((q.shape[0], nheads_q), device=q.device, dtype=torch.float32) - stride_lse_m, stride_lse_h = softmax_lse.stride() + total_seqlen_q, _, _ = q.shape + softmax_lse = torch.zeros((nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32) + stride_lse_h, stride_lse_m = softmax_lse.stride() stride_lse_z = 0 else: - softmax_lse = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) + softmax_lse = torch.zeros((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() - # Seed the RNG so we get reproducible results for testing. - philox_seed = 0x1BF52 - philox_offset = 0x1D4B42 - if bias is not None: bias_strides = (bias.stride(0), bias.stride(1),bias.stride(2), bias.stride(3)) else: bias_strides = (0, 0, 0, 0) - if alibi_slopes is not None: - alibi_strides = (alibi_slopes.stride(0), alibi_slopes.stride(1)) - else: - alibi_strides = (0, 0) - - - attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, - *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, - dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, scores=scores, - scores_scaled_shifted=scores_scaled_shifted, exp_scores=exp_scores, alibi_slopes=alibi_slopes, + attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx, + descale_q, descale_k, descale_v, descale_o, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, + sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, + *bias_strides, stride_az, stride_ah, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen, + MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=is_varlen, IS_INFERENCE=is_inference, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, - USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p - > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_scores) + USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p + > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT) - return o, softmax_lse, exp_scores, grid, head_size, philox_seed, philox_offset, scores, scores_scaled_shifted + return softmax_lse, sd_mask if return_softmax else None diff --git 
a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py index 1cc51d17e73..baefb2410c1 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py @@ -1,9 +1,12 @@ import torch import math -from .utils import DEBUG +from typing import Literal, Optional +from .utils import DEBUG, compute_alibi_tensor_ref -def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): - if DEBUG: +DEBUG_CORE = False + +def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2): + if DEBUG_CORE: print() print("attention_forward_core_ref_impl") print("q:", q, q.shape) @@ -11,18 +14,42 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): print("v:", v, v.shape) print("sm_scale:", sm_scale) print("causal:", causal) + print("dropout_p:", dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:", philox_offset) print("use_exp2:", use_exp2) + + # cast to float32 + q = q.to(torch.float32) + k = k.to(torch.float32) + v = v.to(torch.float32) # Compute attention scores - attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32)) - if DEBUG: + attention_scores = torch.matmul(q, k.transpose(-2, -1)) + if DEBUG_CORE: print("attention_scores:", attention_scores, attention_scores.shape) # Scale scores attention_scaled_scores = sm_scale * attention_scores - if DEBUG: + if DEBUG_CORE: print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape) + # Apply ALiBi if slopes are provided + if alibi_slopes is not None: + L_q, L_k = q.shape[1], k.shape[1] + if DEBUG_CORE: + print("alibi_slopes:", alibi_slopes, alibi_slopes.shape) + alibi_bias = compute_alibi_tensor_ref(alibi_slopes, L_q, L_k) + if DEBUG_CORE: + print("alibi_bias:", alibi_bias, alibi_bias.shape) + alibi_bias = alibi_bias.reshape(-1, L_q, L_k) + if DEBUG_CORE: + print("alibi_bias_flat:", alibi_bias, alibi_bias.shape) + attention_scaled_scores = attention_scaled_scores + alibi_bias + if DEBUG_CORE: + print("attention_scaled_scores after alibi:", attention_scaled_scores, attention_scaled_scores.shape) + + # Apply causal mask if necessary if causal: L_q, L_k = q.shape[1], k.shape[1] @@ -30,19 +57,18 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) col_offset = L_q-L_k causal_mask = row_idx >= (col_offset + col_idx) - if DEBUG: + if DEBUG_CORE: print("causal_mask:", causal_mask) # set -inf to places the causal mask is false attention_scaled_scores = attention_scaled_scores.masked_fill( torch.logical_not(causal_mask.unsqueeze(0)), float('-inf') ) - if DEBUG: + if DEBUG_CORE: print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape) - # Compute max for numerical stability max_scores = torch.max(attention_scaled_scores, dim=-1, keepdim=True)[0] - if DEBUG: + if DEBUG_CORE: print("max_scores:", max_scores, max_scores.shape) if causal: # Replace -inf in max_scores with zeros to avoid NaN in subtraction @@ -54,7 +80,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): # Shift scores attention_shifted_scaled_scores = attention_scaled_scores - max_scores - if DEBUG: + if DEBUG_CORE: print("attention_shifted_scaled_scores:", attention_shifted_scaled_scores, attention_shifted_scaled_scores.shape) # Exponentiate @@ -64,12 +90,12 @@ def attention_forward_core_ref_impl(q, 
k, v, sm_scale, causal, use_exp2): else: exp_scores = torch.exp(attention_shifted_scaled_scores) - if DEBUG: + if DEBUG_CORE: print("exp_scores:", exp_scores, exp_scores.shape) # Sum of exponentials sum_exp_scores = torch.sum(exp_scores, dim=-1, keepdim=True) - if DEBUG: + if DEBUG_CORE: print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape) if causal: # if sum of exp scores is 0.0 it means scores where -inf, we cannot compute softmax and softmax_lse. Setting to 1 deals with -inf case cleanly @@ -78,15 +104,32 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): torch.ones_like(sum_exp_scores), sum_exp_scores ) - if DEBUG: + if DEBUG_CORE: print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape) # Compute softmax probabilities - softmax = exp_scores / sum_exp_scores - - if DEBUG: - print("softmax:", softmax, softmax.shape) - + p = exp_scores / sum_exp_scores + + if DEBUG_CORE: + print("softmax:", p, p.shape) + + # apply dropout if specified + if dropout_p > 0.0: + rand_vals = torch.rand(p.shape, generator=torch.Generator(device=p.device).manual_seed(philox_seed), device=p.device, dtype=p.dtype) + dropout_mask, dropout_scale = rand_vals > dropout_p, (1.0 / (1 - dropout_p)) + if DEBUG_CORE: + print("dropout_scale:", dropout_scale) + print("dropout_mask:", dropout_mask) + # Apply dropout mask and scale + # Set -1 for dropped positions and 1 for kept positions in exp_scores + sd_mask = torch.where(dropout_mask, exp_scores, -exp_scores) + p = torch.where(dropout_mask, p , torch.zeros_like(p)) * dropout_scale + if DEBUG_CORE: + print("softmax after dropout:", p) + print("sd_mask:", sd_mask) + else: + sd_mask = exp_scores + # Compute log-sum-exp if use_exp2: LN2 = math.log(2) @@ -99,17 +142,22 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): softmax_lse = max_scores + torch.log(sum_exp_scores) softmax_lse = softmax_lse.squeeze(-1) - if DEBUG: + if DEBUG_CORE: print("softmax_lse:", softmax_lse, softmax_lse.shape) # Compute output - o = torch.matmul(softmax, v.to(torch.float32)).to(torch.float16) - if DEBUG: + o = torch.matmul(p, v) + if DEBUG_CORE: print("o:", o, o.shape) - return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores + # cast back to original dtype + o = o.to(torch.float16) + # softmax_lse = softmax_lse.to(torch.float16) # NOTE: if you cast lse to fp16 it cause accuracy issues. 
keep fp32 + sd_mask = sd_mask.to(torch.float16) + + return o, softmax_lse, sd_mask -def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, use_exp2): +def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2): """Compute reference output and softmax_lse using PyTorch's built-in function""" # Ensure the layout is 'bhsd' @@ -120,34 +168,54 @@ def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout elif layout != "bhsd": raise ValueError(f"Unknown layout {layout}") - # Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format - batch_size, num_heads, seq_len_q, head_dim = q.shape - seq_len_k = k.shape[2] - - # Merge batch and heads dimensions - q = q.reshape(batch_size * num_heads, seq_len_q, head_dim) - k = k.reshape(batch_size * num_heads, seq_len_k, head_dim) - v = v.reshape(batch_size * num_heads, seq_len_k, head_dim) + # Prepare tensors + batch_size, nheads_q, seq_len_q, head_dim = q.shape + batch_size, nheads_k, seq_len_k, head_dim = k.shape + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") + + if group_size != 1: + # MQA or GQA case + # Reshape q to [batch_size, nheads_k, group_size, seq_len_q, head_dim] + q = q.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + # Expand k and v to match group_size + k = k.unsqueeze(2).expand(-1, -1, group_size, -1, -1) + v = v.unsqueeze(2).expand(-1, -1, group_size, -1, -1) + # Flatten the first three dimensions for computation + q = q.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) + k = k.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) + v = v.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) + else: + q = q.reshape(batch_size * nheads_q, seq_len_q, head_dim) + k = k.reshape(batch_size * nheads_k, seq_len_k, head_dim) + v = v.reshape(batch_size * nheads_k, seq_len_k, head_dim) # Call the core attention function - o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores = attention_forward_core_ref_impl( - q, k, v, sm_scale, causal, use_exp2 + o, softmax_lse, sd_mask = attention_forward_core_ref_impl( + q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2 ) - # Reshape outputs back to [batch_size, num_heads, seq_len, head_dim] - o = o.reshape(batch_size, num_heads, seq_len_q, head_dim) - softmax_lse = softmax_lse.reshape(batch_size, num_heads, seq_len_q) - exp_scores = exp_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k) - softmax = softmax.reshape(batch_size, num_heads, seq_len_q, seq_len_k) - attention_shifted_scaled_scores = attention_shifted_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k) - attention_scaled_scores = attention_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k) - attention_scores = attention_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k) + if group_size != 1: + # Reshape outputs back to original dimensions + o = o.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + o = o.reshape(batch_size, nheads_q, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size, nheads_k, group_size, seq_len_q) + softmax_lse = softmax_lse.reshape(batch_size, nheads_q, seq_len_q) + sd_mask = sd_mask.reshape(batch_size, nheads_k, group_size, seq_len_q, seq_len_k) + sd_mask = 
sd_mask.reshape(batch_size, nheads_q, seq_len_q, seq_len_k) + else: + # Standard case + o = o.reshape(batch_size, nheads_q, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size, nheads_q, seq_len_q) + sd_mask = sd_mask.reshape(batch_size, nheads_q, seq_len_q, seq_len_k) # Restore original layout if necessary if layout == "bshd": o = o.transpose(1, 2) - return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores + return o, softmax_lse, sd_mask + def attention_varlen_forward_pytorch_ref_impl( q, @@ -160,6 +228,10 @@ def attention_varlen_forward_pytorch_ref_impl( cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2 ): # Ensure the layout is 'thd' @@ -167,15 +239,21 @@ def attention_varlen_forward_pytorch_ref_impl( raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.") batch_size = cu_seqlens_q.shape[0] - 1 - num_heads = q.shape[1] + nheads_q, nheads_k = q.shape[1], k.shape[1] head_dim = q.shape[2] # Pre-allocate outputs total_L_q = q.shape[0] total_L_k = k.shape[0] - o = torch.empty((total_L_q, num_heads, head_dim), dtype=q.dtype, device=q.device) - softmax_lse = torch.empty((total_L_q, num_heads), dtype=torch.float32, device=q.device) + o = torch.zeros((total_L_q, nheads_q, head_dim), dtype=q.dtype, device=q.device) + softmax_lse = torch.zeros((total_L_q, nheads_q), dtype=torch.float32, device=q.device) + sd_mask = torch.zeros((batch_size, nheads_q, max_seqlen_q, max_seqlen_k), dtype=torch.float32, device=q.device) + + # Compute group_size for MQA/GQA handling + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") for i in range(batch_size): # Get the start and end indices for the current sequence @@ -184,136 +262,126 @@ def attention_varlen_forward_pytorch_ref_impl( start_k = cu_seqlens_k[i].item() end_k = cu_seqlens_k[i + 1].item() + seqlen_q = end_q - start_q + seqlen_k = end_k - start_k + + if DEBUG: + print(f"Batch {i} with seqlen_q = {seqlen_q}, seqlen_k = {seqlen_k}, Hq= {nheads_q}, Hk = {nheads_k}") + # Extract q_i, k_i, v_i - q_i = q[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim] - k_i = k[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim] - v_i = v[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim] + q_i = q[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] + k_i = k[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] + v_i = v[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] - # Permute to [num_heads, L_q_i, head_dim] + # Permute to [nheads, L_q_i, head_dim] q_i = q_i.permute(1, 0, 2) k_i = k_i.permute(1, 0, 2) v_i = v_i.permute(1, 0, 2) + # Handle MQA/GQA by adjusting shapes based on group_size + if group_size != 1: + # Reshape q_i to [nheads_k, group_size, L_q_i, head_dim] + q_i = q_i.reshape(nheads_k, group_size, seqlen_q, head_dim) + # Expand k_i and v_i to match group_size + k_i = k_i.unsqueeze(1).expand(-1, group_size, -1, -1) + v_i = v_i.unsqueeze(1).expand(-1, group_size, -1, -1) + # Flatten the first two dimensions for computation + q_i = q_i.reshape(nheads_k * group_size, seqlen_q, head_dim) + k_i = k_i.reshape(nheads_k * group_size, seqlen_k, head_dim) + v_i = v_i.reshape(nheads_k * group_size, seqlen_k, head_dim) + else: + # Standard case + q_i = q_i.reshape(nheads_q, seqlen_q, head_dim) + k_i = k_i.reshape(nheads_k, seqlen_k, head_dim) + v_i = v_i.reshape(nheads_k, seqlen_k, head_dim) + + if alibi_slopes is not None: + alibi_slopes_i = 
alibi_slopes[i] + else: + alibi_slopes_i = None + # Call the core attention function for this sequence - ( - o_i, - softmax_lse_i, - exp_scores_i, - softmax_i, - attention_shifted_scaled_scores_i, - attention_scaled_scores_i, - attention_scores_i, - ) = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, use_exp2) - - # Convert back to 'thd' layout and float16 - o_i = o_i.permute(1, 0, 2).to(torch.float16) # [L_q_i, num_heads, head_dim] + o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes_i, use_exp2) + + # Reshape outputs back to original dimensions + if group_size != 1: + # Reshape outputs to [nheads_k, group_size, seqlen_q, head_dim] + o_i = o_i.reshape(nheads_k, group_size, seqlen_q, head_dim) + # Combine the first two dimensions back to nheads_q + o_i = o_i.reshape(nheads_q, seqlen_q, head_dim) + # Reshape softmax_lse_i similarly + softmax_lse_i = softmax_lse_i.reshape(nheads_k, group_size, seqlen_q) + softmax_lse_i = softmax_lse_i.reshape(nheads_q, seqlen_q) + else: + # Outputs are already in the correct shape + pass + + # Convert back to 'thd' layout + o_i = o_i.permute(1, 0, 2) # [L_q_i, nheads_q, head_dim] + softmax_lse_i = softmax_lse_i.permute(1, 0) # [L_q_i, nheads_q] + sd_mask_i = sd_mask_i # [nheads_q, L_q_i, L_k_i] # Place outputs in pre-allocated tensors o[start_q:end_q, :, :] = o_i - softmax_lse[start_q:end_q, :] = softmax_lse_i.transpose(0, 1) # Transpose to [L_q_i, num_heads] - - # For variable-sized outputs, map them into the preallocated tensors - # exp_scores_i: [num_heads, L_q_i, L_k_i] -> [L_q_i, num_heads, L_k_i] - exp_scores_i = exp_scores_i.permute(1, 0, 2) - softmax_i = softmax_i.permute(1, 0, 2) - attention_shifted_scaled_scores_i = attention_shifted_scaled_scores_i.permute(1, 0, 2) - attention_scaled_scores_i = attention_scaled_scores_i.permute(1, 0, 2) - attention_scores_i = attention_scores_i.permute(1, 0, 2) - - return ( - o, - softmax_lse, - None, - None, - None, - None, - None, - ) + softmax_lse[start_q:end_q, :] = softmax_lse_i + sd_mask[i, :, :seqlen_q, :seqlen_k] = sd_mask_i + + return o, softmax_lse, sd_mask -def attention_forward_pytorch_ref_impl( - q, - k, - v, - sm_scale, - causal, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - use_exp2 - ): - if DEBUG: - print() - print("attention_forward_pytorch_ref_impl") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("sm_scale:", sm_scale) - print("causal:", causal) - print("cu_seqlens_q:", cu_seqlens_q) - print("cu_seqlens_k:", cu_seqlens_k) - print("max_seqlen_q:", max_seqlen_q) - print("max_seqlen_k:", max_seqlen_k) - print("use_exp2:", use_exp2) - # compute reference +def attention_forward_pytorch_ref_impl( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + use_exp2: bool +): + # compute reference if layout == "thd": - ( - o_ref, - softmax_lse_ref, - exp_scores_ref, - softmax_ref, - attention_shifted_scaled_scores_ref, - attention_scaled_scores_ref, - attention_scores_ref, - ) = attention_varlen_forward_pytorch_ref_impl( + o_ref, softmax_lse_ref, sd_mask_ref = attention_varlen_forward_pytorch_ref_impl( 
q.clone(), k.clone(), v.clone(), sm_scale, - causal, + causal, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2, ) else: - ( - o_ref, - softmax_lse_ref, - exp_scores_ref, - softmax_ref, - attention_shifted_scaled_scores_ref, - attention_scaled_scores_ref, - attention_scores_ref, - ) = attention_vanilla_forward_pytorch_ref_impl( - q.clone(), k.clone(), v.clone(), sm_scale, causal, layout, use_exp2 - ) - - if DEBUG: - print() - print("attention_forward_pytorch_ref_impl outputs") - print("o_ref:", o_ref, o_ref.shape) - print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape) - print("exp_scores_ref:", exp_scores_ref, exp_scores_ref.shape if exp_scores_ref is not None else None) - - return ( - o_ref, - softmax_lse_ref, - exp_scores_ref, - softmax_ref, - attention_shifted_scaled_scores_ref, - attention_scaled_scores_ref, - attention_scores_ref, - ) - - -def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k): - q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) - k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) - relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) - return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) \ No newline at end of file + o_ref, softmax_lse_ref, sd_mask_ref = attention_vanilla_forward_pytorch_ref_impl( + q.clone(), + k.clone(), + v.clone(), + sm_scale, + causal, + layout, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, + use_exp2) + + # copy back to ouput tensor + out.copy_(o_ref.to(out.dtype)) + + return softmax_lse_ref, sd_mask_ref diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index 59a306d5d6a..06ab7d24d56 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -2,34 +2,43 @@ import os from .fwd_prefill import attention_prefill_forward_triton_impl from .bwd_prefill import attention_prefill_backward_triton_impl +from .bwd_prefill_split import attention_prefill_backward_triton_split_impl +from .bwd_prefill_fused import _flash_attn_backward as attention_prefill_backward_triton_fused_impl +from .bwd_prefill_onekernel import attention_prefill_backward_triton_split_oneKernel_impl from .fwd_decode import attention_decode_forward_triton_impl from .fwd_ref import attention_forward_pytorch_ref_impl from .bwd_ref import attention_backward_pytorch_ref_impl -from .utils import MetaData, get_shape_from_layout, DEBUG - -USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes') - -def fwd(q, - k, - v, - o, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - gen_): - +from .utils import DEBUG, USE_REF, MetaData, get_shapes_from_layout, is_fp8 +from einops import rearrange, repeat +from flash_attn.layers.rotary import apply_rotary_emb +from typing import Literal, Optional, Union + +def fwd(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: 
Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None + ): + if DEBUG: print() - print("flash_attn_triton_amd.py::fwd") + print("flash_attn_triton_amd.py::fwd inputs") print("q:", q, q.shape) print("k:", k, k.shape) print("v:", v, v.shape) - print("o:", o) + print("out:", out, out.shape if out is not None else None) print("alibi_slopes:", alibi_slopes) print("dropout_p:", dropout_p) print("softmax_scale:", softmax_scale) @@ -37,15 +46,17 @@ def fwd(q, print("window_size_left:", window_size_left) print("window_size_right:", window_size_right) print("softcap:", softcap) - print("softcap:", softcap) print("return_softmax:", return_softmax) - - - if dropout_p != 0.0: - raise ValueError("dropout is not supported on AMD's Triton Backend yet") - - if o is None: - o = torch.empty_like(q) + print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) + print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) + print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) + print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) + + if is_fp8(q): + assert out is not None, "fp8 output tensor should be passed in." + assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"For fp8, you need to pass descale factors for q, k and v" + else: + out = torch.zeros_like(q) if out is None else out.zero_() # Setup metadata metadata = MetaData(sm_scale=softmax_scale) @@ -55,111 +66,127 @@ def fwd(q, if return_softmax: metadata.return_scores = True - batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, metadata.layout) - + batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, metadata.layout) + if causal: - metadata.need_causal() - + metadata.need_causal(True) + if alibi_slopes is not None: metadata.need_alibi(alibi_slopes, batch, nheads_q) - - if dropout_p > 0.0: - metadata.need_dropout(dropout_p, return_softmax) - - # Check arguments - metadata.check_args(q, k, v, o) + + # store rng state + metadata.need_dropout(dropout_p, return_softmax) + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensors uses the underlying data and doesnot cast + + # check arguments + metadata.check_args(q, k, v, out) + + # call implementation if USE_REF: if DEBUG: print("Using reference implementation") - (output, - softmax_lse, - exp_scores, - _, - _, - _, - _) = attention_forward_pytorch_ref_impl( - q, - k, + softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + q, + k, v, - metadata.sm_scale, + out, + metadata.sm_scale, + metadata.alibi_slopes, metadata.causal, - metadata.layout, - metadata.cu_seqlens_q, + metadata.layout, + metadata.cu_seqlens_q, metadata.cu_seqlens_k, - metadata.max_seqlens_q, + metadata.max_seqlens_q, metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, metadata.use_exp2) - o.copy_(output) + softmax_lse=softmax_lse_ref + sd_mask=sd_mask_ref else: if DEBUG: print("Using Triton implementation") - (_, - softmax_lse, - exp_scores, - _, - _, - _, - _, - _, - _) = attention_prefill_forward_triton_impl( - q, - k, - v, - o, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - metadata.bias, - metadata.dropout_p, - metadata.layout, - metadata.cu_seqlens_q, + softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( + q, + k, + v, + out, + metadata.sm_scale, + 
metadata.alibi_slopes, + metadata.causal, + None, + metadata.layout, + metadata.cu_seqlens_q, metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.return_scores, - metadata.use_exp2) + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.cache_seqlens, + metadata.cache_batch_idx, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_scores, + metadata.use_exp2, + descale_q, + descale_k, + descale_v, + descale_o) + softmax_lse=softmax_lse_triton + sd_mask=sd_mask_triton if DEBUG: - print("fwd outputs") - print("o:", o, o.shape) + print("flash_attn_triton_amd.py::fwd outputs") + print("o:", out, out.shape) + if is_fp8(out): + print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("exp_scores:", exp_scores, exp_scores.shape if exp_scores is not None else None ) + print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) - return o, softmax_lse, exp_scores, None + return out, softmax_lse, sd_mask, rng_state +BWD_MODE = os.environ.get('BWD_MODE', 'split').lower() def bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - deterministic, - gen_, - rng_state, + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + gen_: Optional[torch.Tensor] = None, + rng_state:Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None, + descale_dq: Optional[torch.Tensor] = None, + descale_dk: Optional[torch.Tensor] = None, + descale_dv: Optional[torch.Tensor] = None, ): if DEBUG: print() - print("flash_attn_triton_amd.py::bwd") + print("flash_attn_triton_amd.py::bwd inputs") print("dout:", dout, dout.shape) print("q:", q, q.shape) print("k:", k, k.shape) print("v:", v, v.shape) print("out:", out, out.shape) print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) + print("dq:", dq, dq.shape if dq is not None else None) + print("dk:", dk, dk.shape if dk is not None else None) + print("dv:", dv, dv.shape if dv is not None else None) print("alibi_slopes:", alibi_slopes) print("dropout_p:", dropout_p) print("out:", out) @@ -170,37 +197,30 @@ def bwd( print("deterministic:", deterministic) print("gen_:", gen_) print("rng_state:", rng_state) + print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) + print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) + print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) + print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) + print("descale_do:", descale_do, descale_do.shape if descale_do is not None else None) + print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None) + print("descale_dk:", descale_dk, 
descale_dk.shape if descale_dk is not None else None) + print("descale_dv:", descale_dv, descale_dv.shape if descale_dv is not None else None) + + dq = torch.zeros_like(q) if dq is None else dq.zero_() + dk = torch.zeros_like(k) if dk is None else dk.zero_() + dv = torch.zeros_like(v) if dv is None else dv.zero_() + + if rng_state is not None: + philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() + else: + philox_seed, philox_offset = None, None - if dropout_p != 0.0: - raise ValueError("dropout is not supported on AMD yet") - + # call implementation if USE_REF: if DEBUG: print("Using reference implementation") - dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale, - causal, - "bshd", - None, - None, - None, - None, - False, - ) - dq.copy_(dq_ref) - dk.copy_(dk_ref) - dv.copy_(dv_ref) - delta = delta_ref - else: - if DEBUG: - print("Using Triton implementation") - dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( + + delta_ref = attention_backward_pytorch_ref_impl( dout, q, k, @@ -218,39 +238,144 @@ def bwd( None, None, None, + dropout_p, + philox_seed, + philox_offset, False, ) - delta = delta_triton + delta = delta_ref + else: + if DEBUG: + print("Using Triton implementation") + if BWD_MODE == "split": + delta_triton = attention_prefill_backward_triton_split_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "bshd", + None, + None, + None, + None, + dropout_p, + philox_seed, + philox_offset, + False, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, + ) + delta = delta_triton + elif BWD_MODE == "fused": + delta_triton = attention_prefill_backward_triton_fused_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + None, + None, + q.shape[1], + k.shape[1], + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + descale_o, + True, + ) + delta = delta_triton + elif BWD_MODE == "jingning": + delta_triton = attention_prefill_backward_triton_split_oneKernel_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "bshd", + None, + None, + None, + None, + dropout_p, + philox_seed, + philox_offset, + False + ) + delta = delta_triton + else: + raise ValueError(f"Unknown bwd mode {BWD_MODE}") if DEBUG: - print("bwd outputs") + print("flash_attn_triton_amd.py::bwd outputs") print("dv:", dv, dv.shape) + if is_fp8(dv): + print("descale_dv:", descale_dv, descale_dv.shape if descale_dv is not None else None) print("dk:", dk, dk.shape) + if is_fp8(dk): + print("descale_dk:", descale_dk, descale_dk.shape if descale_dk is not None else None) print("dq:", dq, dq.shape) + if is_fp8(dq): + print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None) return dq, dk, dv, delta def varlen_fwd( - q, - k, - v, - o, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - leftpad_k, - block_table_, - alibi_slopes,\ - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - zero_tensors, - causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - gen_): + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], 
+ block_table_: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + zero_tensors: bool , + causal: bool , + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None + ): if DEBUG: print() @@ -269,120 +394,135 @@ def varlen_fwd( print("window_size_left:", window_size_left) print("window_size_right:", window_size_right) print("gen_:", gen_) + print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) + print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) + print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - if dropout_p != 0.0: - raise ValueError("dropout is not supported on AMD's Triton Backend yet") - - if o is None: - o = torch.empty_like(q) + if is_fp8(q): + assert out is not None, "fp8 output tensor should be passed in." + assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"For fp8, you need to pass descale factors for q, k and v" + else: + out = torch.zeros_like(q) if out is None else out.zero_() # Setup metadata metadata = MetaData(sm_scale=softmax_scale) if return_softmax: metadata.return_scores = True - metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metadata + metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # set layout to "thd" and other metadata + assert metadata.layout is not None # get shapes - batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) + batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shapes_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) if causal: - metadata.need_causal() + metadata.need_causal(True) if alibi_slopes is not None: metadata.need_alibi(alibi_slopes, batch, nheads_q) - - if dropout_p > 0.0: - metadata.need_dropout(dropout_p, return_softmax) - + + # store rng state + metadata.need_dropout(dropout_p, return_softmax) + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensor uses the underlying data and does not cast + # Check arguments - metadata.check_args(q, k, v, o) - if o is None: - o = torch.empty_like(q, dtype=v.dtype) + metadata.check_args(q, k, v, out) + # call implementation if USE_REF: if DEBUG: print("Using reference implementation") - (output, - softmax_lse, - exp_scores, - _, - _, - _, - _) = attention_forward_pytorch_ref_impl( - q, - k, + softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + q, + k, v, - metadata.sm_scale, + out, + metadata.sm_scale, + metadata.alibi_slopes, metadata.causal, - metadata.layout, - metadata.cu_seqlens_q, + metadata.layout, + metadata.cu_seqlens_q, metadata.cu_seqlens_k, - metadata.max_seqlens_q, + metadata.max_seqlens_q, metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, metadata.use_exp2) - o.copy_(output) + softmax_lse=softmax_lse_ref + sd_mask=sd_mask_ref else: if DEBUG: print("Using Triton implementation") - (_, - softmax_lse, - exp_scores, - _, - _, - _, - _, - _, - _) = attention_prefill_forward_triton_impl( - q, - k,
- v, - o, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - metadata.bias, - metadata.dropout_p, - metadata.layout, - metadata.cu_seqlens_q, + softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( + q, + k, + v, + out, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + None, + metadata.layout, + metadata.cu_seqlens_q, metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.return_scores, - metadata.use_exp2) + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.cache_seqlens, + metadata.cache_batch_idx, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_scores, + metadata.use_exp2, + descale_q, + descale_k, + descale_v, + descale_o) + softmax_lse=softmax_lse_triton + sd_mask=sd_mask_triton + if DEBUG: print("varlen_fwd outputs") - print("o:", o, o.shape) + print("out:", out, out.shape) print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("exp_scores:", exp_scores, exp_scores.shape if exp_scores is not None else None ) + print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) - return o, softmax_lse, exp_scores, None + return out, softmax_lse, sd_mask, rng_state def varlen_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - zero_tensors, - causal, - window_size_left, - window_size_right, - softcap, - deterministic, - gen_, - rng_state, + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + zero_tensors: bool, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + gen_ : Optional[torch.Tensor] = None, + rng_state: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None, + descale_dq: Optional[torch.Tensor] = None, + descale_dk: Optional[torch.Tensor] = None, + descale_dv: Optional[torch.Tensor] = None, ): if DEBUG: print() @@ -391,17 +531,17 @@ def varlen_bwd( print("q:", q, q.shape) print("k:", k, k.shape) print("v:", v, v.shape) + print("out:", out) print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) + print("dq:", dq, dq.shape if dq is not None else None) + print("dk:", dk, dk.shape if dk is not None else None) + print("dv:", dv, dv.shape if dv is not None else None) print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) print("alibi_slopes:", alibi_slopes) print("max_seqlen_q:", max_seqlen_q) print("max_seqlen_k:", max_seqlen_k) print("dropout_p:", dropout_p) - print("out:", out) print("softmax_scale:", softmax_scale) print("causal:", causal) print("window_size_left:", window_size_left) @@ -409,37 +549,52 @@ def varlen_bwd( print("deterministic:", deterministic) print("gen_:", gen_) print("rng_state:", rng_state) + print("descale_q:", descale_q, descale_q.shape if 
descale_q is not None else None) + print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) + print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) + print("descale_do:", descale_do, descale_do.shape if descale_do else None) + + dq = torch.zeros_like(q) if dq is None else dq.zero_() + dk = torch.zeros_like(k) if dk is None else dk.zero_() + dv = torch.zeros_like(v) if dv is None else dv.zero_() - if dropout_p != 0.0: - raise ValueError("dropout is not supported on AMD yet") + if rng_state is not None: + philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() + else: + philox_seed, philox_offset = None, None + # call implementation if USE_REF: if DEBUG: print("Using reference implementation") - dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( + delta_ref = attention_backward_pytorch_ref_impl( dout, q, k, v, out, softmax_lse, + dq, + dk, + dv, softmax_scale, + alibi_slopes, causal, "thd", cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, False, ) - dq.copy_(dq_ref) - dk.copy_(dk_ref) - dv.copy_(dv_ref) delta = delta_ref else: if DEBUG: - print("Using Triton implementation") - dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( + print("Using Triton implementation") + delta_triton = attention_prefill_backward_triton_split_impl( dout, q, k, @@ -457,7 +612,18 @@ def varlen_bwd( cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, False, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, ) delta = delta_triton @@ -471,29 +637,54 @@ def varlen_bwd( return dq, dk, dv, delta def fwd_kvcache( - q, - k_cache, - v_cache, - k, - v, - cache_seqlens, - rotary_cos, - rotary_sin, - cache_batch_idx, - cache_leftpad, - block_table, - alibi_slopes, - out, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - rotary_interleaved, - num_splits): - - if out is None: - out = torch.empty_like(q) + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + cache_seqlens: Optional[Union[(int, torch.Tensor)]], + rotary_cos: Optional[torch.Tensor], + rotary_sin: Optional[torch.Tensor], + cache_batch_idx: Optional[torch.Tensor], + cache_leftpad: Optional[torch.Tensor], + block_table: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + out: Optional[torch.Tensor], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + rotary_interleaved: bool, + num_splits: int + ): + + if DEBUG: + print() + print("flash_attn_triton_amd.py::fwd_kvcache inputs") + print("q:", q, q.shape) + print("k_cache:", k_cache, k_cache.shape) + print("v_cache:", v_cache, v_cache.shape) + print("k:", k, k.shape if k is not None else None) + print("v:", v, v.shape if v is not None else None) + print("cache_seqlens:", cache_seqlens ) + print("rotary_cos:",rotary_cos ) + print("rotary_sin:",rotary_sin) + print("cache_batch_idx:", cache_batch_idx) + print("cache_leftpad:", cache_leftpad) + print("block_table:", block_table) + print("alibi_slopes:", alibi_slopes) + print("out:", out) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("rotary_interleaved:", rotary_interleaved) 
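# Illustrative sketch (not part of this patch): how the bwd/varlen_bwd paths above round-trip the
# dropout RNG state. Assumes philox_seed and philox_offset are plain Python ints held on MetaData.
import torch

philox_seed, philox_offset = 114514, 1920834                # hypothetical values for illustration
rng_state = torch.as_tensor([philox_seed, philox_offset])   # packed after the forward pass (int64 tensor, no float cast)
seed, offset = rng_state[0].item(), rng_state[1].item()     # unpacked in the backward pass before the kernels run
assert (seed, offset) == (philox_seed, philox_offset)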
+ print("num_splits:", num_splits) + + # output + out = torch.zeros_like(q) if out is None else out.zero_() # fill metadata metadata = MetaData(sm_scale=softmax_scale) @@ -503,33 +694,99 @@ def fwd_kvcache( metadata.cache_seqlens = cache_seqlens metadata.cache_batch_idx = cache_batch_idx - if k is not None and v is not None: - metadata.new_kv = True - metadata.seqlen_new = k.shape[1] - metadata.k_new = k - metadata.v_new = v + k_new = k + v_new = v if causal: - metadata.need_causal() + metadata.need_causal(True) if alibi_slopes is not None: batch, _ , nheads_q, _= q.shape metadata.need_alibi(alibi_slopes, batch, nheads_q) + # rotary boolean + apply_rotary = torch.is_tensor(rotary_cos) and torch.is_tensor(rotary_sin) + if apply_rotary: + metadata.need_rotary(rotary_sin, rotary_cos, rotary_interleaved) + + # Rotary Embedding Implementation + if apply_rotary: + if metadata.causal: # NOTE: when support is added. Add `or metadata.local` + q_ro = apply_rotary_emb( + q, + metadata.rotary_cos, + metadata.rotary_sin, + seqlen_offsets=metadata.cache_seqlens, + interleaved=metadata.rotary_interleaved, + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + metadata.rotary_cos, + metadata.rotary_sin, + seqlen_offsets=metadata.cache_seqlens, + interleaved=metadata.rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=metadata.max_seqlens_q, + ) + k_ro = apply_rotary_emb( + k_new, + metadata.rotary_cos, + metadata.rotary_sin, + seqlen_offsets=metadata.cache_seqlens, + interleaved=metadata.rotary_interleaved, + ) + + q, k_new = q_ro.to(q.dtype), k_ro.to(q.dtype) + # launch kernel - # TODO: pass output as an arg. Maybe we are copying output which is causing slow down - output, softmax_lse = attention_decode_forward_triton_impl( - q, - k_cache, - v_cache, - metadata.sm_scale, - metadata.causal, - metadata.alibi_slopes, - metadata.layout, - metadata.cache_seqlens, - metadata.cache_batch_idx, - metadata.new_kv, - metadata.k_new, - metadata.v_new, - ) - return output, softmax_lse + DECODE_KERNEL= True # os.environ.get('DECODE_KERNEL', '0').lower() in ('1', 'true', 'yes') + if DECODE_KERNEL: + softmax_lse_triton = attention_decode_forward_triton_impl( + q, + k_cache, + v_cache, + k_new, + v_new, + out, + metadata.sm_scale, + metadata.causal, + metadata.alibi_slopes, + metadata.layout, + metadata.cache_seqlens, + metadata.cache_batch_idx, + ) + else: + softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( + q, + k_cache, + v_cache, + out, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + None, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.cache_seqlens, + metadata.cache_batch_idx, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_scores, + metadata.use_exp2, + None, + None, + None, + None) + softmax_lse = softmax_lse_triton + + if DEBUG: + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + return out, softmax_lse diff --git a/flash_attn/flash_attn_triton_amd/interface_torch.py b/flash_attn/flash_attn_triton_amd/interface_torch.py deleted file mode 100644 index d4906606eda..00000000000 --- a/flash_attn/flash_attn_triton_amd/interface_torch.py +++ /dev/null @@ -1,97 +0,0 @@ -import torch -from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill import attention_prefill_backward_triton_impl -from .fwd_decode import 
attention_decode_forward_triton_impl - - -class _attention_prefill(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, o, metadata): - (output, - softmax_lse, - exp_scores, - grid, - head_size, - philox_seed, - philox_offset, - _, - _) = attention_prefill_forward_triton_impl( - q, - k, - v, - o, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - metadata.bias, - metadata.dropout_p, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.return_scores, - metadata.use_exp2) - - ctx.save_for_backward(q, k, v, o, softmax_lse) - ctx.grid = grid - ctx.sm_scale = metadata.sm_scale - ctx.head_size = head_size - ctx.causal = metadata.causal - ctx.alibi_slopes = metadata.alibi_slopes - ctx.dropout_p = metadata.dropout_p - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.exp_scores = exp_scores - ctx.return_scores = metadata.return_scores - ctx.layout = metadata.layout - ctx.use_exp2 = metadata.use_exp2 - return output, softmax_lse, exp_scores - - @staticmethod - def backward(ctx, do, *args): - q, k, v, o, softmax_lse = ctx.saved_tensors - return attention_prefill_backward_triton_impl( - do, - q, - k, - v, - o, - softmax_lse, - None, - None, - None, - ctx.sm_scale, - ctx.alibi_slopes, - ctx.causal, - ctx.layout, - None, - None, - None, - None, - ctx.use_exp2 - ) - -attention_prefill = _attention_prefill.apply - - -class _attention_decode(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, metadata): - output, softmax_lse = attention_decode_forward_triton_impl( - q, - k, - v, - metadata.sm_scale, - metadata.causal, - metadata.alibi_slopes, - metadata.layout, - metadata.cache_seqlens, - metadata.cache_batch_idx, - metadata.new_kv, - metadata.k_new, - metadata.v_new, - ) - return output, softmax_lse - -attention_decode = _attention_decode.apply diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index 9a6ab8dab28..58e2ae5fc7f 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -1,617 +1,348 @@ +import os +import glob +import shutil +import time import torch import pytest - -from .utils import MetaData, get_input_shapes, input_helper, varlen_input_helper, DEBUG -from .interface_torch import attention_prefill, attention_decode -from .fwd_ref import attention_forward_pytorch_ref_impl, compute_alibi_tensor_ref +import logging +import numpy as np +from pathlib import Path +from flash_attn import ( + flash_attn_func, + flash_attn_fp8_func, + flash_attn_kvpacked_func, + flash_attn_qkvpacked_func, + flash_attn_qkvpacked_fp8_func, + flash_attn_varlen_func, + flash_attn_varlen_fp8_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_qkvpacked_fp8_func +) + +from .utils import DEBUG, input_helper, arch_supports_fp8 +from .fwd_ref import attention_forward_pytorch_ref_impl from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill import attention_prefill_backward_triton_impl +from .bwd_prefill_split import attention_prefill_backward_triton_split_impl from .bwd_ref import attention_backward_pytorch_ref_impl -from .fwd_decode import dequantize_kv_fp16, quantize_kv_int4 + +# set print options +# torch.set_printoptions(linewidth=5e5, edgeitems=10, sci_mode=False) +# np.set_printoptions(linewidth=5000, threshold=1e4, suppress=True, precision=4) # defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. 
See table https://pytorch.org/docs/stable/testing.html ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose. # ATOL, RTOL = 1e-3, 1e-3 # catchs fa mismatch issues # ATOL, RTOL = 1e-4, 1e-3 # to strict. there will be small diffs # ATOL, RTOL = 1e-5, 1e-3 # # default fp16. there will be small diffs +# ATOL_fp8, RTOL_fp8 = 1e-1, 1e-1 # to strict for larger tensors in fp8 +ATOL_fp8, RTOL_fp8 = 2.5e-1, 2.5e-1 # fp8 +# ATOL_fp8, RTOL_fp8 = 2e-2, 2e-2 # fp8 EQUAL_NAN = True -@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 24, 1024, 1024, 64), - (1, 24, 6, 8192, 8192, 64), - (1, 4, 2, 16384, 16384, 128), - (2, 16, 4, 1020, 987, 128), - (2, 16, 4, 15498, 2, 128), - (2, 16, 2, 7, 16219, 64), - (4, 48, 12, 1, 1, 64), - (4, 48, 48, 1, 1, 128), - (4, 48, 24, 3, 3, 128), - (4, 48, 48, 1001, 990, 64), - (1, 8, 8, 8081, 7099, 64), - (1, 4, 4, 16330, 15989, 128), - (4, 4, 1, 1024, 1024, 33), - (4, 4, 2, 65, 1018, 65), - (4, 4, 4, 128, 128, 65), - (4, 4, 4, 113, 123, 1), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('use_alibi', [True, False]) -@pytest.mark.parametrize('layout', ['bshd', 'bhsd']) -def test_op_fwd_prefill(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16): - torch.manual_seed(20) - q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) - if causal: - input_metadata.need_causal() - - if use_alibi: - # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) - alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, - device="cuda").repeat(Z, 1) - input_metadata.need_alibi(alibi_slopes, Z, HQ) - else: - alibi_slopes = None - - o = torch.empty_like(q) - - # triton implementation - tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata) - - # Transpose here if layout is bshd so we have same reference code for all layouts - if layout == 'bshd': - q = q.transpose(1, 2).clone() - k = k.transpose(1, 2).clone() - v = v.transpose(1, 2).clone() - # Replicate K and V if using MQA/GQA - if HQ != HK: - k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], - k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) - v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], - v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) - - scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale - if causal: - mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) - scores[:, :, mask == 0] = float("-inf") - if use_alibi: - scores += compute_alibi_tensor_ref(alibi_slopes, N_CTX_Q, N_CTX_K) - - p = torch.softmax(scores, dim=-1) - if causal: - # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into - # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix - # this by converting the NaNs to 0s, which is what they should be out of the softmax. 
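# Illustrative sketch (not part of the diff): the fully-masked-row behaviour the comment above
# describes. A row that is entirely -inf comes out of softmax as NaN (since -inf - -inf == NaN),
# and zeroing those entries restores the value the masked row should produce.
import torch

scores = torch.full((1, 4), float("-inf"))   # one query row with every key masked out
p = torch.softmax(scores, dim=-1)            # -> tensor([[nan, nan, nan, nan]])
p = torch.where(torch.isnan(p), torch.zeros_like(p), p)
print(p)                                     # tensor([[0., 0., 0., 0.]])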
- nan_mask = torch.isnan(p) - p[nan_mask == 1] = 0 - ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) - # compare - if layout == 'bshd': - ref_out = ref_out.transpose(1, 2).clone() - torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) - - -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 1024, 1024, 64), - (4, 12, 8192, 8192, 64), - (2, 4, 16384, 16384, 128), - (2, 16, 15498, 2, 128), - (2, 4, 7, 16219, 64), - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 48, 1001, 990, 64), - (1, 8, 8081, 7099, 64), - (1, 8, 16330, 15989, 128), - (4, 4, 1024, 1024, 33), - (4, 4, 65, 1019, 65), - (4, 4, 128, 128, 65), - # TODO: This config fails. Disabled until triaged and fixed. - # (2, 16, 1020, 987, 128), - # (4, 4, 113, 123, 1), -]) +@pytest.mark.parametrize( + "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (1, 1, 1, 1, 1, 1), + (1, 1, 1, 2, 4, 16), + (1, 2, 2, 2, 4, 16), + (1, 4, 1, 2, 4, 16), + (1, 4, 2, 2, 4, 16), + (1, 1, 1, 4, 2, 16), + (1, 1, 1, 4, 4, 16), + (1, 2, 2, 4, 4, 16), + (2, 1, 1, 4, 4, 16), + (2, 2, 2, 4, 4, 16), + (1, 1, 1, 128, 64, 16), + (2, 2, 2, 2, 128, 1), + (2, 3, 3, 2, 128, 16), + (3, 2, 2, 256, 512, 16), + (3, 3, 3, 128, 128, 64), + (2, 4, 4, 1024, 1024, 64), + (4, 6, 6, 108, 256, 224), + (4, 8, 8, 2048, 2048, 128), + (4, 16, 16, 4096, 4096, 64), + (2, 4, 4, 8192, 8192, 32), + # fa configs + (4, 6, 1, 113, 203, 256), + (4, 6, 1, 128, 217, 256), + (4, 6, 2, 113, 211, 128), + (4, 6, 2, 108, 256, 128), + (4, 6, 1, 256, 512, 64), + (4, 6, 1, 512, 256, 64), + (4, 6, 2, 1024, 1024, 32), + (4, 6, 2, 1023, 1024, 32), + (4, 6, 6, 1024, 1023, 32), + (4, 6, 6, 2048, 2048, 32), + ], +) @pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('use_bias', [True]) -def test_op_fwd_prefill_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16): - torch.manual_seed(20) - sm_scale = D_HEAD**-0.5 - input_metadata = MetaData(sm_scale=sm_scale) - q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd') - if causal: - input_metadata.need_causal() - if use_bias: - bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda") - input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K) - else: - bias = None - o = torch.empty_like(q) - - # triton implementation - tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata) - # reference implementation:171 - - scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale - if causal: - mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) - scores[:, :, mask == 0] = float("-inf") - if use_bias: - scores += input_metadata.bias - p = torch.softmax(scores, dim=-1) - if causal: - # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into - # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix - # this by converting the NaNs to 0s, which is what they should be out of the softmax. 
- nan_mask = torch.isnan(p) - p[nan_mask == 1] = 0 - ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) - # compare - torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) - - -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ - (4, 48, 8192, 64), - (4, 48, 256, 64), - (4, 48, 512, 64), - (4, 48, 1024, 64), - (8, 48, 4096, 64), - (4, 48, 8192, 64), - (4, 48, 128, 128), - (4, 48, 4096, 128), - (4, 48, 16384, 128), - (4, 16, 1024, 128), - (4, 16, 8192, 128), - (32, 48, 8192, 128) - ] - ) -@pytest.mark.parametrize('causal', [True, False]) -def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): - - q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) - - tri_out = torch.empty_like(q) - ref_out = torch.empty_like(q) - - for i in range(0, input_metadata.num_contexts): - start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] - end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] - scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k[start_k:end_k]).float() - p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() - ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v[start_k:end_k]) - attention_prefill(q, k, v, tri_out, input_metadata) - torch.testing.assert_close(ref_out, tri_out, atol=ATOL, rtol=RTOL) - - -@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), (4, 48, 12, 256, 64), (4, 48, 4, 512, 64), - (4, 48, 2, 1024, 64), (8, 48, 6, 4096, 64), (4, 48, 8, 16384, 64), - (4, 64, 16, 128, 128), (4, 64, 4, 4096, 128), - (4, 64, 8, 16384, 128), (4, 16, 4, 1024, 128), - (4, 16, 2, 8192, 128), (32, 128, 32, 8192, 128)]) -@pytest.mark.parametrize('causal', [False]) -def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): - q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype) - ref_out = torch.empty_like(q) - tri_out = torch.empty_like(q) - # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so the - # size aligns with Q. 
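# Illustrative sketch (not part of the diff): the "groups" expansion the comment above refers to,
# shown standalone with hypothetical sizes. Each of the HK key/value heads is repeated HQ // HK
# times so the reference einsum can treat the problem as plain MHA with HQ heads.
import torch

HQ, HK, total_tokens, d = 8, 2, 4, 16
k = torch.randn(total_tokens, HK, d)
k_expanded = k.view(total_tokens, HK, 1, d).expand(-1, -1, HQ // HK, -1).reshape(total_tokens, HQ, d)
print(k_expanded.shape)  # torch.Size([4, 8, 16])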
- k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand(-1, -1, HQ // HK, -1) - v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(-1, -1, HQ // HK, -1) - for i in range(0, input_metadata.num_contexts): - start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] - end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] - k_curr = k_ref[start_k:end_k] - k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) - v_curr = v_ref[start_k:end_k] - v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) - scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k_curr).float() - p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() - ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) - attention_prefill(q, k, v, tri_out, input_metadata) - torch.testing.assert_close(ref_out, tri_out, atol=ATOL, rtol=RTOL) - - -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - # smallest config test - (1, 1, 16, 16, 64), # pass on new # fail on old - (1, 1, 32, 32, 64), # pass on new # fail on old - (1, 1, 64, 64, 16), # pass # smallest head_size = 16 - (1, 1, 64, 64, 64), # pass # smallest seq len seems to be 64 - (1, 1, 128, 128, 64), # pass - (1, 1, 256, 256, 64), # pass - (1, 1, 512, 512, 64), # pass - # failing FA - (1, 1, 256, 512, 16), - # old tests that work - (4, 48, 1024, 1024, 64), # pass - (4, 48, 2048, 2048, 64), # pass - (2, 48, 4096, 4096, 64), # pass - (1, 16, 1024, 1024, 64), # pass - (1, 16, 1024, 1024, 128), # pass - # old tests that were commented out - # (1, 16, 8192, 8192, 63), - # (1, 16, 1022, 1022, 64), -]) -# @pytest.mark.parametrize('torch_sdpa_test', [False, True]) -@pytest.mark.parametrize('torch_sdpa_test', [False]) -# @pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('causal', [False]) -# @pytest.mark.parametrize('use_alibi', [False, True]) -@pytest.mark.parametrize('use_alibi', [False]) -def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_alibi, dtype=torch.float16): - torch.manual_seed(20) - - DEBUG_INPUT = False - - # seqlens - seqlen_q = N_CTX_Q - seqlen_k = N_CTX_K - - # setup up metadata - if DEBUG_INPUT: - sm_scale = 1 - else: - sm_scale = D_HEAD**-0.5 - input_metadata = MetaData(sm_scale=sm_scale) - input_metadata.max_seqlens_q = seqlen_q - input_metadata.max_seqlens_k = seqlen_k - input_metadata.layout = "bhsd" - - dropout_p = 0 - if DEBUG_INPUT: - q = torch.arange(seqlen_q, dtype=dtype, device="cuda").view(1, 1, seqlen_q, 1).expand(Z, H, seqlen_q, D_HEAD).requires_grad_() - k = torch.arange(seqlen_k, dtype=dtype, device="cuda").view(1, 1, seqlen_k, 1).expand(Z, H, seqlen_k, D_HEAD).requires_grad_() - v = torch.arange(seqlen_k, dtype=dtype, device="cuda").view(1, 1, seqlen_k, 1).expand(Z, H, seqlen_k, D_HEAD).requires_grad_() - o = torch.zeros_like(q) - else: - # Generate random inputs - q = torch.randn(Z, H, N_CTX_Q, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) - k = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) - v = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) - o = torch.empty_like(q) - - if causal: - input_metadata.need_causal() - - if use_alibi and not torch_sdpa_test: - # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) - alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32, - device="cuda").repeat(Z, 1) - input_metadata.need_alibi(alibi_slopes, Z, 
H) - - if DEBUG_INPUT: - dout = torch.ones_like(q) - else: - dout = torch.randn_like(q) - - # reference implementation - if torch_sdpa_test: - ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p, - is_causal=causal, scale=sm_scale, - dropout_mask=None) - ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None - else: - M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda")) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - if use_alibi: - p += compute_alibi_tensor_ref(alibi_slopes, N_CTX_Q, N_CTX_K) - if causal: - p[:, :, M == 0] = float("-inf") - - p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype) - ref_out = torch.matmul(p, v) - ref_out.backward(dout) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None - - # # triton implementation - tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata) - tri_out.backward(dout) - tri_dv, v.grad = v.grad.clone(), None - tri_dk, k.grad = k.grad.clone(), None - tri_dq, q.grad = q.grad.clone(), None - # compare - if DEBUG: - print("tri_out:", tri_out) - print("ref_out:",ref_out ) - torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) - - # The current block size for MI200 series is 64x64. This results in - # larger differences in float results due to rounding. - if dtype == torch.bfloat16: - ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) - if dtype == torch.float32: - ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0) - else: - ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) - - RTOL = 0 - - if DEBUG: - print("ref_dv:", ref_dv) - print("tri_dv:", tri_dv) - print("ref_dk:", ref_dk) - print("tri_dk:", tri_dk) - print("ref_dq:", ref_dq) - print("tri_dq:", tri_dq) - - torch.testing.assert_close(ref_dv, tri_dv, atol=ATOL, rtol=RTOL) - torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) - torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) - - -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (1, 1, 1, 1, 1), - (1, 1, 2, 4, 16), - (1, 1, 4, 2, 16), - (1, 1, 4, 4, 16), - (1, 2, 4, 4, 16), - (2, 1, 4, 4, 16), - (2, 2, 4, 4, 16), - (1, 1, 128, 64, 16), - (2, 2, 2, 128, 1), - (2, 3, 2, 128, 16), - (3, 2, 256, 512, 16), - (3, 3, 128, 128, 64), - (2, 4, 1024, 1024, 64), - (4, 6, 108, 256, 224), - (4, 8, 2048, 2048, 128), - (4, 16, 4096, 4096, 64), - (2, 4, 8192, 8192, 32), - # # fa configs - (4, 6, 113, 203, 256), - (4, 6, 128, 217, 256), - (4, 6, 113, 211, 128), - (4, 6, 108, 256, 128), - (4, 6, 256, 512, 64), - (4, 6, 512, 256, 64), - (4, 6, 1024, 1024, 32), - (4, 6, 1023, 1024, 32), - (4, 6, 1024, 1023, 32), - (4, 6, 2048, 2048, 32), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('return_scores', [False]) -@pytest.mark.parametrize('layout', ["bhsd", "bshd", "thd"]) +@pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize('alibi_slopes', [None]) +@pytest.mark.parametrize('layout', ["bshd", "thd"]) +@pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('use_exp2', [True, False]) # works when use_exp2 is false @pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. 
Just use to figure out issues -def test_op_prefill_fwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scores, layout, use_exp2, DEBUG_INPUT): - dtype = torch.float16 - torch.manual_seed(0) - alibi_slopes = None - dropout_p = 0.0 +def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, alibi_slopes, layout, dtype, use_exp2, DEBUG_INPUT): + torch.manual_seed(42) device = "cuda" - if layout == "thd": - q, k, v, metadata = varlen_input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - else: - q, k, v, metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT) - if DEBUG_INPUT: - output_triton = torch.zeros_like(q).contiguous() - else: - output_triton = torch.empty_like(q) + q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, dtype, layout=layout, device=device) + + if DEBUG: + if HQ // HK != 1: + print("MQA/GQA") + else: + print("MHA") # update metadata metadata.use_exp2 = use_exp2 if causal: - metadata.need_causal() + metadata.need_causal(True) # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - if return_scores: - metadata.return_scores = True + metadata.need_dropout(dropout_p) + # call Triton's forward implementation directly - ( output_triton, - softmax_lse_triton, - exp_scores_triton, - _, - _, - _, - _, - _, - _) = attention_prefill_forward_triton_impl( - q, - k, - v, - output_triton, + q_triton = q.clone() + k_triton = k.clone() + v_triton = v.clone() + o_triton = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) + softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( + q_triton, + k_triton, + v_triton, + o_triton, metadata.sm_scale, metadata.alibi_slopes, metadata.causal, metadata.bias, - metadata.dropout_p, metadata.layout, metadata.cu_seqlens_q, metadata.cu_seqlens_k, metadata.max_seqlens_q, - metadata.max_seqlens_k, + metadata.max_seqlens_k, + metadata.cache_seqlens, + metadata.cache_batch_idx, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, metadata.return_scores, - metadata.use_exp2) - - ( - output_ref, - softmax_lse_ref, - exp_scores_ref, - softmax_ref, - attention_shifted_scaled_scores_ref, - attention_scaled_scores_ref, - attention_scores_ref, - ) = attention_forward_pytorch_ref_impl( - q.clone(), - k.clone(), - v.clone(), - metadata.sm_scale, + metadata.use_exp2, + None, + None, + None, + None) + + # ref forward + q_ref = q.clone() + k_ref = k.clone() + v_ref = v.clone() + o_ref = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) + softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + q_ref, + k_ref, + v_ref, + o_ref, + metadata.sm_scale, + metadata.alibi_slopes, causal, layout, metadata.cu_seqlens_q, metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, use_exp2 ) + if DEBUG: + print() + print("Compare Triton Impl with refernce Pytorch Impl") + + # this can be set to true manually or when using dropout + if metadata.return_scores: + if DEBUG: + print("sd_mask_triton:", sd_mask_triton, sd_mask_triton.shape) + print("sd_mask_ref:", sd_mask_ref, sd_mask_ref.shape) + torch.testing.assert_close(sd_mask_triton.to(sd_mask_ref.dtype), sd_mask_ref, atol=ATOL, rtol=RTOL) + if DEBUG: print("softmax_lse_triton:", 
softmax_lse_triton, softmax_lse_triton.shape) print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape) torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL) - - if layout != "thd": - # use trick with lse to get the softmax. you need the scores but is it - softmax_triton = torch.exp(attention_scaled_scores_ref - softmax_lse_triton.unsqueeze(-1)) - if DEBUG: - print("attention_scaled_scores_ref:", attention_scaled_scores_ref, attention_scaled_scores_ref.shape) - print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape) - print("softmax_triton:", softmax_triton, softmax_triton.shape) - print("softmax_ref:", softmax_ref, softmax_ref.shape) - torch.testing.assert_close(softmax_triton, softmax_ref, atol=ATOL, rtol=RTOL) if DEBUG: - print("output_triton:", output_triton, output_triton.shape) - print("output_ref:", output_ref, output_ref.shape) - torch.testing.assert_close(output_triton, output_ref, atol=ATOL, rtol=RTOL) - - - # compare with pytorch expect thd and causal impl is different - if False and layout in ["bhsd", "bshd"] and not causal: - out_pytorch, softmax_pytorch = torch.ops.aten._scaled_dot_product_attention_math( - q.transpose(1, 2) if layout == "bshd" else q , - k.transpose(1, 2) if layout == "bshd" else k, - v.transpose(1, 2) if layout == "bshd" else v, - dropout_p=dropout_p, - is_causal=causal, scale=metadata.sm_scale, - dropout_mask=None) - out_pytorch = out_pytorch.transpose(1, 2) if layout == "bshd" else out_pytorch - - if DEBUG: - print("o:", output_triton, output_triton.shape) - print("out_pytorch:", out_pytorch, out_pytorch.shape) - torch.testing.assert_close(output_triton, out_pytorch, atol=ATOL, rtol=RTOL) - - # compare with pytorch output - if DEBUG: - print("softmax_triton:", softmax_triton, softmax_triton.shape) - print("softmax_pytorch:", softmax_pytorch, softmax_pytorch.shape) - torch.testing.assert_close(softmax_triton, softmax_pytorch.to(torch.float32), atol=ATOL, rtol=RTOL) - - -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (1, 1, 1, 1, 1), - (1, 1, 4, 4, 4), - (2, 1, 4, 4, 16), - (1, 2, 4, 4, 16), - (2, 2, 4, 4, 16), - (1, 1, 4, 4, 16), - (2, 1, 4, 4 , 16), - (4, 6, 8, 8 , 16), - (1, 1, 4, 4, 32), - (1, 1, 16, 16, 16), - (1, 1, 32, 32, 16), - (1, 1, 64, 64, 16), - (1, 1, 64, 64, 64), - (1, 1, 64, 128, 32), - (1, 1, 128, 128, 64), - (1, 1, 128, 256, 45), - (1, 1, 113, 203, 192), - (1, 1, 256, 256, 64), - (1, 1, 256, 512, 16), - (1, 1, 512, 512, 64), - (1, 1, 1024, 1024, 64), + print("output_triton:", o_triton, o_triton.shape) + print("output_ref:", o_ref, o_ref.shape) + torch.testing.assert_close(o_triton, o_ref, atol=ATOL, rtol=RTOL) + +@pytest.mark.parametrize( + "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", [ + (1, 1, 1, 1, 1, 1), + (1, 1, 1, 4, 4, 4), + (2, 1, 1, 4, 4, 16), + (1, 2, 2, 4, 4, 16), + (1, 4, 1, 2, 4, 16), + (1, 8, 1, 2, 4, 16), + (1, 16, 1, 2, 4, 16), + (1, 32, 1, 2, 4, 16), + (1, 64, 1, 2, 4, 16), + (1, 4, 2, 2, 4, 16), + (2, 2, 2, 4, 4, 16), + (1, 1, 1, 4, 4, 16), + (2, 1, 1, 4, 4 , 16), + (4, 6, 6, 8, 8 , 16), + (1, 1, 1, 4, 4, 32), + (1, 1, 1, 16, 16, 16), + (1, 1, 1, 32, 32, 16), + (1, 1, 1, 64, 64, 16), + (1, 1, 1, 64, 64, 16), + (1, 1, 1, 64, 128, 16), + (1, 1, 1, 64, 64, 32), + (1, 1, 1, 64, 128, 32), + (1, 1, 1, 128, 128, 64), + (1, 1, 1, 128, 256, 45), + (1, 1, 1, 113, 203, 192), + (1, 1, 1, 256, 256, 64), + (1, 1, 1, 256, 512, 16), + (1, 1, 1, 512, 512, 64), + (1, 1, 1, 1024, 1024, 64), # fa configs - (2, 2, 128, 128, 65), - (2, 2, 128, 128, 224), - (4, 6, 108, 
256, 224), - (1, 1, 256, 512, 16), + (2, 2, 2, 128, 128, 65), + (2, 2, 2, 128, 128, 224), + (4, 6, 6, 108, 256, 224), + (1, 1, 1, 256, 512, 16), # old tests that work - (4, 48, 1024, 1024, 73), - (4, 48, 1024, 1024, 64), - (4, 48, 2048, 2048, 64), - (1, 24, 4096, 4096, 64), - (1, 16, 1024, 1024, 64), - (1, 16, 1024, 1024, 128), + (4, 48, 6, 1024, 1024, 64), + (4, 48, 12, 2048, 1024, 64), + (4, 48, 24, 1024, 1024, 64), + (4, 48, 48, 1024, 1024, 64), + (4, 48, 48, 1024, 1024, 73), + (4, 48, 48, 2048, 2048, 64), + (1, 24, 24, 4096, 4096, 64), + (1, 16, 16, 1024, 1024, 64), + (1, 16, 16, 1024, 1024, 128), + # testcase new + # seqlen q == k + (1, 1, 1, 2, 2, 2), # small enough to debug + (1, 1, 1, 128, 128, 32), # only one block + (1, 1, 1, 127, 127, 32), # only one block but with masking + (1, 1, 1, 129, 129, 1), # two blocks with 2nd block small enough to debug + (1, 1, 1, 350, 350, 1), # two blocks with 2nd block small enough to debug + (1, 1, 1, 350, 350, 68), # generic masking on q, k and head + (4, 1, 1, 512, 512, 128), # batch > 1 + (4, 8, 2, 512, 512, 128), # GQA + (4, 8, 2, 512, 512, 68), # non-power-of-2 head_dim + (4, 8, 2, 500, 500, 68), # comprehensive case for seqlen q == k + # seqlen q > k + (1, 1, 1, 64, 32, 8), # seqlen_q > seqlen_k + (1, 1, 1, 192, 128, 32), # seqlen_q > seqlen_k + (4, 8, 2, 1024, 512, 68), # seqlen_q < seqlen_k + (1, 1, 1, 729, 516, 68), # seqlen_q > seqlen_k + (16, 16, 4, 2753, 1528, 68), # a comprehensive seqlen_q > seqlen_k + # seqlen q < k + (1, 1, 1, 32, 64, 8), # seqlen_q > seqlen_k + (1, 1, 1, 128, 192, 32), # seqlen_q < seqlen_k + (4, 8, 2, 512, 1024, 68), # seqlen_q < seqlen_k + (1, 1, 1, 200, 413, 1), # seqlen_q < seqlen_k + (1, 1, 1, 782, 1546, 1), # seqlen_q < seqlen_k + (16, 16, 4, 1528, 2753, 68), # a comprehensive seqlen_q < seqlen_k ]) @pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize('alibi_slopes', [None]) +@pytest.mark.parametrize('layout', ["bshd", "thd"]) +@pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('use_exp2', [False]) # FIXME: using exp2 causes issue when used with causal -@pytest.mark.parametrize('layout', ["bhsd", "bshd", "thd"]) -@pytest.mark.parametrize('sequence_parallel', [True, False]) -@pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans in both new and old backend -def test_op_prefill_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, layout, sequence_parallel, DEBUG_INPUT): - dtype = torch.float16 - torch.manual_seed(20) # seed from test_op_bwd +@pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans on larger tensors +def test_op_prefill_bwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, alibi_slopes, layout, dtype, use_exp2, DEBUG_INPUT): + torch.manual_seed(20) + device="cuda" - alibi_slopes = None - if layout == "thd": - q, k, v, metadata = varlen_input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, DEBUG_INPUT=DEBUG_INPUT) - else: - q, k, v, metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, DEBUG_INPUT=DEBUG_INPUT) - if DEBUG_INPUT: - do = torch.ones_like(q).contiguous() - else: - do = torch.randn_like(q) + # gen inputs + q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, dtype, layout=layout, device=device) + + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. 
We are not doing that + metadata.need_dropout(dropout_p) # =============================================== Reference ============================================================== + # fwd q_ref = q.clone() k_ref = k.clone() - v_ref = v.clone() - ( - o_ref, - softmax_lse_ref, - _, - _, - _, - _, - _, - ) = attention_forward_pytorch_ref_impl( + v_ref = v.clone() + output_ref = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) + softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( q_ref, k_ref, v_ref, - metadata.sm_scale, + output_ref, + metadata.sm_scale, + metadata.alibi_slopes, causal, layout, metadata.cu_seqlens_q, metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, use_exp2 ) - dq = torch.zeros_like(q, dtype=q.dtype) # NOTE: the kernel does inplace accumlation on dq so dq has to be zeros - if DEBUG_INPUT: - dk = torch.zeros_like(k, dtype=k.dtype) - dv = torch.zeros_like(v, dtype=v.dtype) - else: - dk = torch.empty_like(k, dtype=k.dtype) - dv = torch.empty_like(v, dtype=v.dtype) - + # bwd do_ref = do.clone() - dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( + dq_ref = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) + dk_ref = torch.zeros_like(k).contiguous() if DEBUG_INPUT else torch.empty_like(k) + dv_ref = torch.zeros_like(v).contiguous() if DEBUG_INPUT else torch.empty_like(v) + delta_ref = attention_backward_pytorch_ref_impl( do_ref, q_ref, k_ref, v_ref, - o_ref, + output_ref, softmax_lse_ref, + dq_ref, + dk_ref, + dv_ref, metadata.sm_scale, + metadata.alibi_slopes, causal, layout, metadata.cu_seqlens_q, metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, use_exp2 ) # =============================================== Triton ============================================================== - o = o_ref.clone().contiguous() - softmax_lse = softmax_lse_ref.clone().contiguous() - dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( - do, - q, - k, - v, - o, - softmax_lse, - dq, - dk, - dv, + do_triton = do.clone() + q_triton = q.clone() + k_triton = k.clone() + v_triton = v.clone() + o_triton = output_ref.clone().contiguous() + softmax_lse_triton = softmax_lse_ref.clone().contiguous() + dq_triton = torch.zeros_like(q_triton, dtype=q.dtype) # NOTE: the kernel does inplace accumulation on dq so dq has to be zeros + dk_triton = torch.zeros_like(k_triton, dtype=k.dtype) if DEBUG_INPUT else torch.empty_like(k_triton, dtype=k.dtype) + dv_triton = torch.zeros_like(v_triton, dtype=v.dtype) if DEBUG_INPUT else torch.empty_like(v_triton, dtype=v.dtype) + delta_triton = attention_prefill_backward_triton_split_impl( + do_triton, + q_triton, + k_triton, + v_triton, + o_triton, + softmax_lse_triton, + dq_triton, + dk_triton, + dv_triton, metadata.sm_scale, alibi_slopes, causal, @@ -620,8 +351,18 @@ def test_op_prefill_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, l metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, use_exp2, - sequence_parallel=sequence_parallel + None, + None, + None, + None, + None, + None, + None, + None, ) # =============================================== Check ============================================================== @@ -647,78 +388,545 @@ def test_op_prefill_bwd_impl(Z, H,
N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, l print("dq_ref:", dq_ref, dq_ref.shape) torch.testing.assert_close(dq_triton, dq_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) +def fp8_assert_close(tensor_a, tensor_b, atol=ATOL_fp8, rtol=RTOL_fp8, max_diff_percentage=0.5): + """Assert tensors are close with tolerance for small percentage of elements""" + # standard comparison + abs_diff = torch.abs(tensor_a - tensor_b) + rel_diff = abs_diff / torch.abs(tensor_b.clamp(min=1e-6)) + + # calculate elements that exceed tolerance + abs_check = abs_diff > atol + rel_check = rel_diff > rtol + failed_check = torch.logical_and(abs_check, rel_check) + + # calculate percentage of failed elements + failed_percentage = failed_check.sum().item() / failed_check.numel() * 100 + + # if percentage is small enough, test passes + if failed_percentage <= max_diff_percentage: + return True + + # Otherwise, provide diagnostic information + max_abs_idx = torch.argmax(abs_diff).item() + max_rel_idx = torch.argmax(rel_diff).item() + + flat_to_idx = lambda flat_idx, shape: np.unravel_index(flat_idx, shape) + + max_abs_pos = flat_to_idx(max_abs_idx, tensor_a.shape) + max_rel_pos = flat_to_idx(max_rel_idx, tensor_a.shape) + + max_abs_diff = abs_diff.flatten()[max_abs_idx].item() + max_rel_diff = rel_diff.flatten()[max_rel_idx].item() + + raise AssertionError( + f"Tensors not close enough! {failed_percentage:.6f}% elements exceed tolerance.\n" + f"Greatest absolute difference: {max_abs_diff} at index {max_abs_pos} (up to {atol} allowed)\n" + f"Greatest relative difference: {max_rel_diff} at index {max_rel_pos} (up to {rtol} allowed)" + ) + +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + # seqlen q == k + (1, 1, 1, 1, 1, 1), + (1, 1, 1, 2, 2, 2), # small enough to debug + (1, 1, 1, 4, 4, 16), + (1, 2, 2, 4, 4, 16), + (2, 1, 1, 4, 4, 16), + (2, 2, 2, 4, 4, 16), + (1, 1, 1, 128, 128, 32), # only one block + (3, 3, 3, 128, 128, 64), + (1, 1, 1, 127, 127, 32), # only one block but with masking + # (1, 1, 1, 129, 129, 1), # two blocks with 2nd block small enough to debug # fails + (1, 2, 2, 129, 129, 32), # two blocks with 2nd block small enough to debug + (1, 1, 1, 350, 350, 32), # two blocks with 2nd block small enough to debug + (1, 1, 1, 350, 350, 68), # generic masking on q, k and head + (4, 1, 1, 512, 512, 128), # batch > 1 + (4, 2, 2, 512, 512, 128), + (4, 2, 2, 512, 512, 68), + (4, 2, 2, 500, 500, 68), + (2, 4, 4, 1024, 1024, 64), + (4, 8, 8, 2048, 2048, 128), + (2, 8, 8, 4096, 4096, 64), + (2, 4, 4, 8192, 8192, 32), + # seqlen q > k + (1, 1, 1, 4, 2, 16), + (1, 1, 1, 64, 32, 8), + (1, 1, 1, 128, 64, 16), + (1, 1, 1, 192, 128, 32), + (1, 2, 2, 1024, 512, 68), + (1, 4, 4, 729, 516, 68), + (2, 4, 4, 2753, 1528, 68), # a comprehensive seqlen_q > seqlen_k + # seqlen q < k + (1, 1, 1, 2, 4, 16), + (1, 2, 2, 2, 4, 16), + (1, 4, 1, 2, 4, 16), + (1, 4, 2, 2, 4, 16), + (2, 2, 2, 2, 128, 1), + (2, 3, 3, 2, 128, 16), + (1, 1, 1, 32, 64, 8), + (1, 1, 1, 128, 192, 32), + (4, 6, 6, 108, 256, 32), + (3, 2, 2, 256, 512, 16), + (2, 2, 2, 512, 1024, 68), + (1, 1, 1, 200, 413, 32), + (1, 1, 1, 782, 1546, 32), + # gqa/mqa # mismatch issue on varlen + (4, 8, 2, 500, 500, 68), + (4, 8, 2, 512, 512, 68), + (4, 8, 2, 512, 512, 128), + (4, 8, 2, 512, 1024, 68), + (4, 8, 2, 1024, 512, 64), + (4, 16, 4, 1528, 2753, 68), + # fa configs + (2, 4, 1, 113, 203, 64), + (2, 4, 2, 128, 217, 64), + (2, 6, 2, 113, 211, 128), + (2, 6, 2, 108, 256, 128), + (2, 6, 2, 256, 512, 64), + (2, 6, 2, 512, 256, 64), + (2, 6, 2, 1024, 1024, 
32), + (2, 6, 2, 1023, 1024, 32), + (2, 6, 6, 1024, 1023, 32), + (2, 6, 6, 2048, 2048, 32), + ], +) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize('layout', ["bshd", "thd"]) +@pytest.mark.parametrize('packing', [None, "qkv"]) +@pytest.mark.parametrize('DEBUG_INPUT', [False]) +@pytest.mark.flaky(reruns=3, reason="Retry failures") +@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") +def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, packing, DEBUG_INPUT): + torch.manual_seed(20) + test_backward = True + device = "cuda" + window_size = (-1, -1) + softcap = 0.0 + alibi_slopes = None + deterministic = False + ref_dtype = torch.float32 + is_varlen = True if layout == "thd" else False + + # skip QKV packing tests for uneven sequence lengths and head sizes + if packing == 'qkv': + if N_CTX_Q != N_CTX_K: + pytest.skip("QKV packing requires N_CTX_Q == N_CTX_K") + if HQ != HK: + pytest.skip("QKV packing requires HQ == HK") + + # test apis + if packing == 'qkv': + # generate inputs + qkv, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, packing=packing, device=device, DEBUG_INPUT=DEBUG_INPUT) + + # ---------------------------------------------------------------- + # --- FP8 --- + # ---------------------------------------------------------------- + qkv_fp8 = qkv.clone() + do_fp8= do.clone() + + if is_varlen: + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_qkvpacked_fp8_func( + qkv_fp8, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_qkvpacked_fp8_func( + qkv_fp8, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + # ---------------------------------------------------------------- + # --- Reference --- + # ---------------------------------------------------------------- + # reference forward pass + qkv_ref = qkv.clone() + do_ref= do.clone() + + if is_varlen: + out_ref, lse_ref, S_dmask_ref = flash_attn_varlen_qkvpacked_func( + qkv_ref, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out_ref, lse_ref, S_dmask_ref = flash_attn_qkvpacked_func( + qkv_ref, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + # ---------------------------------------------------------------- + # --- Compare --- + # ---------------------------------------------------------------- + # compare forward + if DEBUG: + print() + print(f"Compare fp8 against ref with dtype {ref_dtype}") -@pytest.mark.parametrize('batch_size, seqlen_q, seqlen_k, group_q, group_k, dim', get_input_shapes()) -def test_op_fwd_decode(batch_size, seqlen_q, seqlen_k, group_q, group_k, dim, dtype=torch.bfloat16): - if DEBUG: - print() - print(f"batch_size = {batch_size}, seqlen_q = {seqlen_q}, seqlen_k = {seqlen_k}, group_q = {group_q}, group_k = {group_k}, dim = {dim}") + if DEBUG: + print("out_ref:", out_ref, out_ref.shape) + 
print("out_fp8:", out_fp8, out_fp8.shape) + fp8_assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) + + + if DEBUG: + print("lse_ref:", lse_ref, lse_ref.shape) + print("lse_fp8:", lse_fp8, lse_fp8.shape) + fp8_assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + + + if dropout_p > 0.0: + if DEBUG: + print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape) + print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape) + fp8_assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + + if not test_backward: + return + + # fp8 backward pass + dqkv_fp8, = torch.autograd.grad(out_fp8, (qkv_fp8), do_fp8) + + # ref backward pass + dqkv_ref, = torch.autograd.grad(out_ref, (qkv_ref), do_ref) + + # compare backward gradients + if DEBUG: + print("dqkv_ref:", dqkv_ref, dqkv_ref.shape) + print("dqkv_fp8:", dqkv_fp8, dqkv_fp8.shape) + fp8_assert_close(dqkv_ref, dqkv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + + elif packing is None: + # generate inputs + q, k, v, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT) + + # ---------------------------------------------------------------- + # --- FP8 --- + # ---------------------------------------------------------------- + if DEBUG: + print() + print(f"Compute Fp8 Forward") + q_fp8 = q.clone() + k_fp8 = k.clone() + v_fp8 = v.clone() + do_fp8= do.clone() + + if is_varlen: + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_fp8_func( + q_fp8, + k_fp8, + v_fp8, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_fp8_func( + q_fp8, + k_fp8, + v_fp8, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + # ---------------------------------------------------------------- + # --- Reference --- + # ---------------------------------------------------------------- + if DEBUG: + print() + print(f"Compute Reference Forward") + # reference forward pass + q_ref = q.clone() + k_ref = k.clone() + v_ref = v.clone() + do_ref = do.clone() + + if is_varlen: + out_ref, lse_ref, S_dmask_ref = flash_attn_varlen_func( + q_ref, + k_ref, + v_ref, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out_ref, lse_ref, S_dmask_ref = flash_attn_func( + q_ref, + k_ref, + v_ref, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + # ---------------------------------------------------------------- + # --- Compare --- + # ---------------------------------------------------------------- + # compare forward + if DEBUG: + print() + print(f"Compare fp8 against ref with dtype {ref_dtype}") + + if DEBUG: + print("out_ref:", out_ref, out_ref.shape) + print("out_fp8:", out_fp8, out_fp8.shape) + # torch.testing.assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + fp8_assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) + + + if DEBUG: + print("lse_ref:", lse_ref, 
lse_ref.shape) + print("lse_fp8:", lse_fp8, lse_fp8.shape) + # torch.testing.assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + fp8_assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + + + if dropout_p > 0.0: + if DEBUG: + print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape) + print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape) + # torch.testing.assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + fp8_assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + + if not test_backward: + return + + if DEBUG: + print() + print(f"Compute Fp8 Backward") + # fp8 backward pass + dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8) + + if DEBUG: + print() + print(f"Compute Reference Backward") + # ref backward pass + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), do_ref) + + # compare backward gradients + if DEBUG: + print("dv_ref:", dv_ref, dv_ref.shape) + print("dv_fp8:", dv_fp8, dv_fp8.shape) + # torch.testing.assert_close(dv_ref, dv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) + fp8_assert_close(dv_ref, dv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) + + if DEBUG: + print("dk_ref:", dk_ref, dk_ref.shape) + print("dk_fp8:", dk_fp8, dk_fp8.shape) + # torch.testing.assert_close(dk_ref, dk_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) + fp8_assert_close(dk_ref, dk_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) + + if DEBUG: + print("dq_ref:", dq_ref, dq_ref.shape) + print("dq_fp8:", dq_fp8, dq_fp8.shape) + # torch.testing.assert_close(dq_ref, dq_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) + fp8_assert_close(dq_ref, dq_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) + +@pytest.mark.parametrize( + "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (2, 4, 4, 512, 512, 128), + ], +) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('dropout_p', [0.0, 0.1]) +@pytest.mark.parametrize('layout', ['bshd']) +@pytest.mark.parametrize('packing', [None]) +@pytest.mark.parametrize('test_backward', [False, True]) +@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") +@pytest.mark.skip("Breaks on CI but works locally") +def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, packing, test_backward): # Don't run this test in parallel. It clears the cache so it doesnot work properly if run in parallel. 
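+    # Flow of this test: clear the Triton cache, run the fp8 attention path once so the kernels
+    # compile, then scan the cache for the generated .ttir files and assert that fp8 element
+    # types (f8E4M3 / f8E5M2) actually appear in the IR.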
torch.manual_seed(20) - query_group_head_size = (group_q + group_k - 1) // group_k - q = (torch.empty((batch_size, seqlen_q, group_k, query_group_head_size, dim), dtype=dtype, - device="cuda").normal_(mean=0., std=0.5).requires_grad_()) - k = (torch.empty((batch_size, seqlen_k, group_k, 1, dim), dtype=dtype, - device="cuda").normal_(mean=0., - std=0.5).requires_grad_()).expand(-1, -1, -1, query_group_head_size, -1) - v = (torch.empty((batch_size, seqlen_k, group_k, 1, dim), dtype=dtype, - device="cuda").normal_(mean=0., - std=0.5).requires_grad_()).expand(-1, -1, -1, query_group_head_size, -1) - scale = 1 / dim**0.5 - input_metadata = MetaData(sm_scale=scale) - input_metadata.layout = "bsghd" - tri_out, _ = attention_decode(q, k, v, input_metadata) - - q = q.reshape([batch_size, seqlen_q, -1, dim]).permute(0, 2, 1, 3) - k = k.reshape([batch_size, seqlen_k, -1, dim]).permute(0, 2, 1, 3) - v = v.reshape([batch_size, seqlen_k, -1, dim]).permute(0, 2, 1, 3) - attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) - ref_out = attn @ v - - # compare - torch.testing.assert_close(ref_out, tri_out, atol=1e-3, rtol=0) - -def test_quantization(): - a = torch.randn((2, 4, 32), dtype=torch.float16, device='cuda') - qa = quantize_kv_int4(a, num_groups=4) - dqa = dequantize_kv_fp16(qa, num_groups=4) - torch.testing.assert_close(a, dqa, atol=1.5e-1, rtol=1e-1) - -@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', get_input_shapes()) -def test_op_fwd_decode_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): - pytest.skip("Decode kernel doesnot support quantization yet") - torch.manual_seed(2) - q = (torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, - device="cuda").normal_(mean=1.0, std=0.5).requires_grad_()) - k = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, - device="cuda").normal_(mean=1.0, - std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) - v = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, - device="cuda").normal_(mean=1.0, - std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) - - num_groups = 1 - quant_k = (quantize_kv_int4(k, num_groups=num_groups).contiguous().view(torch.int32)) - quant_v = (quantize_kv_int4(v, num_groups=num_groups).contiguous().view(torch.int32)) - scale = 1 / K**0.5 - input_metadata = MetaData(sm_scale=scale) - input_metadata.layout = "bsghd" - tri_out, _ = attention_decode(q, quant_k, quant_v, input_metadata) - - q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) - k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) - v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) - attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) - ref_out = attn @ v - # compare - torch.testing.assert_close(ref_out, tri_out, atol=2.1e-2, rtol=0) - - # since quantization introduces rounding error, use the - # dequantized kv as inputs to the ref implementation to reduce - # the tolerance to 1e-3 - dqk = dequantize_kv_fp16(quant_k, num_groups=num_groups) - dqv = dequantize_kv_fp16(quant_v, num_groups=num_groups) - dqk = dqk.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) - dqv = dqv.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) - dq_attn = (q @ dqk.transpose(-1, -2) * scale).softmax(-1) - dq_ref_out = dq_attn @ dqv - torch.testing.assert_close(dq_ref_out, tri_out, atol=1e-3, rtol=0) + device = "cuda" + window_size = (-1, -1) + softcap = 0.0 + alibi_slopes = None + deterministic = False + ref_dtype = torch.float32 + is_varlen = True if layout == "thd" else False + + # remove cache + cache_path = Path(os.path.expanduser("~/.triton/cache")) + if 
cache_path.exists(): + shutil.rmtree(cache_path) + os.makedirs(cache_path) + + # inputs + q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout=layout, packing=packing, device=device) + + if packing == None: + # fp8 forward pass + if is_varlen: + out, lse, S_dmask = flash_attn_varlen_fp8_func( + q, + k, + v, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out, lse, S_dmask = flash_attn_fp8_func( + q, + k, + v, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + # fp8 backward pass + if test_backward: + dq, dk, dv = torch.autograd.grad(out, (q, k, v), do) + elif packing == "qkv": + # qkv packing path + # pack input tensors (use dim=1 for varlen, else dim=2) + if is_varlen: + qkv = torch.stack([q, k, v], dim=1) + else: + qkv = torch.stack([q, k, v], dim=2) + + # fp8 forward pass for qkv-packed input + if is_varlen: + out, lse, S_dmask = flash_attn_varlen_qkvpacked_fp8_func( + qkv, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out, lse, S_dmask = flash_attn_qkvpacked_fp8_func( + qkv, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + # fp8 backward pass for qkv-packed input + if test_backward: + dqkv, = torch.autograd.grad(out, (qkv,), do) + else: + raise ValueError(f"unknown packing type {packing}") + + # search for .ttir files + max_retries = 5 + retry_delay = 0.5 + ttir_files = [] + logging.info(f"Checking for .ttir files in {cache_path}...") + for attempt in range(max_retries): + # search for .ttir files recursively within the cache path + ttir_files = glob.glob(str(cache_path) + "/**/*.ttir", recursive=True) + + if ttir_files: + # Files found, log success and exit the loop + logging.info(f"Found {len(ttir_files)} .ttir files on attempt {attempt + 1}.") + break + else: + # Files not found yet + if attempt < max_retries - 1: + # If not the last attempt, wait and log before retrying + logging.warning( + f"No .ttir files found on attempt {attempt + 1}. " + f"Retrying in {retry_delay}s..." + ) + time.sleep(retry_delay) + else: + pytest.fail( + f"FATAL: No .ttir files found in cache {cache_path} " + f"after {max_retries} attempts." 
+ ) + + # check if there is fp8 + ttir_files_fp8_found_status = {} + fp8_types = ['f8E4M3', 'f8E5M2'] + for ttir_file in ttir_files: + base_name = os.path.basename(ttir_file) + with open(ttir_file, 'r') as f: + content = f.read() + + # check content for fp8 + fp8_found = False + for f8_type in fp8_types: + if f8_type in content: + fp8_found = True + ttir_files_fp8_found_status[base_name] = fp8_found + + for file, fp8_found in ttir_files_fp8_found_status.items(): + assert fp8_found, f"{fp8_types} not found in {file}" diff --git a/flash_attn/flash_attn_triton_amd/train.py b/flash_attn/flash_attn_triton_amd/train.py new file mode 100644 index 00000000000..fc5f5d0b1bf --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/train.py @@ -0,0 +1,403 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, Dataset, random_split +import numpy as np +import pandas as pd +from tqdm import tqdm +import matplotlib.pyplot as plt +from datasets import load_dataset +from flash_attn import flash_attn_qkvpacked_func, flash_attn_qkvpacked_fp8_func, flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_fp8_func + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print(f"using device: {device}") + +# ------------------------------- +# Model +# ------------------------------- +class FlashAttention(nn.Module): + def __init__(self, dim, num_heads=8, causal=True, dropout=0.1, qkv_bias=True, use_fp8=False): + super().__init__() + self.use_fp8 = use_fp8 + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.causal = causal + self.dropout_p = dropout + + # qkv and output projections + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + b, n, c = x.shape + # project to qkv + qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + # reshape for flash attention function + qkv_packed = torch.stack([q, k, v], dim=2).reshape(b, n, 3, self.num_heads, self.head_dim) + + # use the appropriate flash attention function + if self.use_fp8: + context = flash_attn_qkvpacked_fp8_func( + qkv_packed, + dropout_p=self.dropout_p, + causal=self.causal + ) + else: + context = flash_attn_qkvpacked_func( + qkv_packed, + dropout_p=self.dropout_p, + causal=self.causal + ) + + # convert back to original shape and type + context = context.reshape(b, n, c) + + # output projection + x = self.proj(context) + + return x + +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4.0, causal=True, dropout=0.1, use_fp8=False): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn = FlashAttention(dim, num_heads=num_heads, causal=causal, dropout=dropout, use_fp8=use_fp8) + + self.norm2 = nn.LayerNorm(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = nn.Sequential( + nn.Linear(dim, mlp_hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(mlp_hidden_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + +class FlashLM(nn.Module): + def __init__( + self, + vocab_size, + dim=256, + depth=6, + num_heads=8, + mlp_ratio=4.0, + causal=True, + dropout=0.1, + max_seq_len=256, + use_fp8=False + ): + super().__init__() + + # embedding layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.position_embedding = nn.Parameter(torch.zeros(1, 
max_seq_len, dim)) + self.dropout = nn.Dropout(dropout) + + # transformer blocks + self.blocks = nn.ModuleList([ + TransformerBlock(dim, num_heads, mlp_ratio, causal=causal, dropout=dropout, use_fp8=use_fp8) + for _ in range(depth) + ]) + + # lm head: project back to vocabulary dimension for each token + self.norm = nn.LayerNorm(dim) + self.lm_head = nn.Linear(dim, vocab_size) + + def forward(self, x): + b, n = x.shape + + # token + positional embedding + x = self.token_embedding(x) + x = x + self.position_embedding[:, :n, :] + x = self.dropout(x) + + # transformer blocks + for block in self.blocks: + x = block(x) + + # language modeling head + x = self.norm(x) + logits = self.lm_head(x) # shape: (b, n, vocab_size) + return logits + +# ------------------------------- +# Data +# ------------------------------- +class TextDataset(Dataset): + def __init__(self, sequences, max_len=None): + self.sequences = sequences + self.max_len = max_len + + def __len__(self): + return len(self.sequences) + + def __getitem__(self, idx): + seq = self.sequences[idx] + # input: all tokens except the last, target: all tokens except the first + return (torch.tensor(seq[:-1], dtype=torch.long), + torch.tensor(seq[1:], dtype=torch.long)) + +class VarLenTextDataset(Dataset): + def __init__(self, sequences, max_len=256): + self.sequences = sequences + self.max_len = max_len + + def __len__(self): + return len(self.sequences) + + def __getitem__(self, idx): + seq = self.sequences[idx] + # Ensure the sequence doesn't exceed max_len+1 + seq = seq[:self.max_len+1] + # input: all tokens except the last, target: all tokens except the first + return (torch.tensor(seq[:-1], dtype=torch.long), + torch.tensor(seq[1:], dtype=torch.long)) + +def prepare_dataset(batch_size, is_varlen=False, min_len=10, max_len=256, ratio_shorter=0.7): + # load the WikiText-2 + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") + + # build vocabulary + corpus = " ".join([line for line in dataset["text"] if line.strip() != ""]) # join non-empty lines into a single corpus string + tokens = corpus.split() + vocab = sorted(set(tokens)) + word2idx = {word: idx for idx, word in enumerate(vocab)} + token_ids = [word2idx[word] for word in tokens] + + num_workers = 2 + if is_varlen: + # VARIABLE LENGTH: create sequences of different lengths + sequences = [] + for i in range(0, len(token_ids) - max_len, max_len // 2): # overlap to get more sequences + # Decide target length for this sequence + if np.random.random() < ratio_shorter: + # Shorter sequence + target_len = np.random.randint(min_len + 1, max_len + 1) + else: + # Full length sequence + target_len = max_len + 1 + + # Extract sequence up to target length or whatever's available + seq_end = min(i + target_len, len(token_ids)) + seq = token_ids[i:seq_end] + + # Only keep sequences that are long enough + if len(seq) > min_len + 1: # +1 because we need both input and target + sequences.append(seq) + + print(f"Created {len(sequences)} variable-length sequences") + + # Get some statistics + lens = [len(seq) for seq in sequences] + print(f"Sequence length stats: min={min(lens)}, max={max(lens)}, mean={np.mean(lens):.1f}") + + # split dataset + num_samples = len(sequences) + num_train = int(0.8 * num_samples) + num_val = num_samples - num_train + + # Use appropriate dataset class based on whether we need variable length + dataset_class = VarLenTextDataset + train_sequences = sequences[:num_train] + val_sequences = sequences[num_train:] + + train_dataset = dataset_class(train_sequences, 
max_len) + val_dataset = dataset_class(val_sequences, max_len) + + + # collate function + def collate_fn(batch): + """ + Collate function that creates a flat representation for variable length flash attention. + """ + # Separate inputs and targets + inputs, targets = zip(*batch) + + # Get sequence lengths + seq_lens = torch.tensor([len(x) for x in inputs], dtype=torch.int32) + + # Concatenate inputs and targets into single tensors + flat_inputs = torch.cat(inputs) + flat_targets = torch.cat(targets) + + # Create cumulative sequence lengths tensor + cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32) + cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0) + + # Calculate max sequence length for this batch + max_seqlen = seq_lens.max().item() + + return flat_inputs, flat_targets, seq_lens, cu_seqlens, max_seqlen + + # data loaders + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn) + val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn) + else: + # FIXED LENGTH: create sequences of length max_len+1 + sequences = [] + for i in range(0, len(token_ids) - max_len, max_len): + seq = token_ids[i : i + max_len + 1] + if len(seq) == max_len + 1: + sequences.append(seq) + + # split dataset + num_samples = len(sequences) + num_train = int(0.8 * num_samples) + num_val = num_samples - num_train + train_dataset, val_dataset = random_split(TextDataset(sequences), [num_train, num_val]) + + # data loaders + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + vocab_size = len(vocab) + print(f"vocab size: {vocab_size}, train samples: {len(train_dataset)}, validation samples: {len(val_dataset)}") + return train_dataloader, val_dataloader, vocab_size + +# ------------------------------- +# Training +# ------------------------------- +def train_lm(model, train_dataloader, val_dataloader, optimizer, criterion, num_epochs): + train_losses = [] + val_losses = [] + for epoch in range(num_epochs): + # Training phase + model.train() + epoch_train_loss = 0.0 + for inputs, targets in tqdm(train_dataloader, desc=f"epoch {epoch+1}/{num_epochs} [train]"): + inputs, targets = inputs.to(device), targets.to(device) + + optimizer.zero_grad() + logits = model(inputs) + loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) + loss.backward() + optimizer.step() + + epoch_train_loss += loss.item() + + epoch_train_loss /= len(train_dataloader) + train_losses.append(epoch_train_loss) + print(f"epoch {epoch+1}/{num_epochs} - train loss: {epoch_train_loss:.4f}") + + # Validation phase + model.eval() + epoch_val_loss = 0.0 + with torch.no_grad(): + for inputs, targets in tqdm(val_dataloader, desc=f"epoch {epoch+1}/{num_epochs} [validation]"): + inputs, targets = inputs.to(device), targets.to(device) + logits = model(inputs) + loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) + epoch_val_loss += loss.item() + epoch_val_loss /= len(val_dataloader) + val_losses.append(epoch_val_loss) + print(f"epoch {epoch+1}/{num_epochs} - validation loss: {epoch_val_loss:.4f}") + + return train_losses, val_losses + +# ------------------------------- +# Main +# ------------------------------- +def main(): + # hyperparameters + batch_size = 16 + num_epochs = 20 + learning_rate = 3e-4 + max_len = 128 # total 
length including both input and target tokens + is_varlen = False + causal=True + dropout=0.1 + + # prep data + print("Preparing Dataset") + train_dataloader, val_dataloader, vocab_size = prepare_dataset(batch_size, max_len=max_len, is_varlen=is_varlen) + + # create language models + print("Creating Models") + model_normal = FlashLM( + vocab_size=vocab_size, + dim=256, + depth=3, + num_heads=8, + causal=causal, + dropout=dropout, + max_seq_len=max_len, + ).to(device) + + model_fp8 = FlashLM( + vocab_size=vocab_size, + dim=256, + depth=3, + num_heads=8, + causal=causal, + dropout=dropout, + max_seq_len=max_len, + use_fp8=True + ).to(device) + + # Train Normal model + print("Starting training for Normal model...") + optimizer_normal = optim.AdamW(model_normal.parameters(), lr=learning_rate) + criterion = nn.CrossEntropyLoss() + normal_train_losses, normal_val_losses = train_lm( + model_normal, train_dataloader, val_dataloader, optimizer_normal, criterion, num_epochs + ) + torch.save(model_normal.state_dict(), 'flash_lm_normal.pth') + print("Normal model training complete and saved.") + + # Train FP8 model + print("Starting training for FP8 model...") + optimizer_fp8 = optim.AdamW(model_fp8.parameters(), lr=learning_rate) + fp8_train_losses, fp8_val_losses = train_lm( + model_fp8, train_dataloader, val_dataloader, optimizer_fp8, criterion, num_epochs + ) + torch.save(model_fp8.state_dict(), 'flash_lm_fp8.pth') + print("FP8 model training complete and saved.") + + # save losses to csv + epochs = range(1, num_epochs+1) + loss_data = { + "Epoch": epochs, + "Normal_Training_Loss": normal_train_losses, + "Normal_Validation_Loss": normal_val_losses, + "FP8_Training_Loss": fp8_train_losses, + "FP8_Validation_Loss": fp8_val_losses, + } + df_losses = pd.DataFrame(loss_data) + df_losses.to_csv("losses.csv", index=False) + print("Loss data saved to losses.csv") + + # plot Training Loss + plt.figure(figsize=(10, 6)) + plt.plot(epochs, normal_train_losses, label="Normal Training Loss", marker='o') + plt.plot(epochs, fp8_train_losses, label="FP8 Training Loss", marker='x') + plt.xlabel("Epoch") + plt.ylabel("Training Loss") + plt.title("Training Loss Comparison: Normal vs FP8 Flash Attention") + plt.legend() + plt.grid(True) + plt.savefig("training_loss.png") # Saves the training loss plot to disk + plt.show() + + # Plot Validation Loss + plt.figure(figsize=(10, 6)) + plt.plot(epochs, normal_val_losses, label="Normal Validation Loss", marker='o') + plt.plot(epochs, fp8_val_losses, label="FP8 Validation Loss", marker='x') + plt.xlabel("Epoch") + plt.ylabel("Validation Loss") + plt.title("Validation Loss Comparison: Normal vs FP8 Flash Attention") + plt.legend() + plt.grid(True) + plt.savefig("validation_loss.png") # Saves the validation loss plot to disk + plt.show() + + +if __name__ == "__main__": + main() diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 530455063e2..5d3bf02e1f8 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -1,32 +1,58 @@ - +import csv +import math import torch import os +import random +import functools import triton +import triton.language as tl +from typing import Literal, Optional, Union +# ------------------------------- +# Gloabl Variables +# ------------------------------- AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes') DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes') 
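+# NOTE: these flags are read once at import time, so the corresponding environment variables
+# must be set before this module is imported.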
+USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes') PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes') - +USE_SINGLE_BWD_KERNEL = os.environ.get('USE_SINGLE_BWD_KERNEL', '0').lower() in ('1', 'true', 'yes') +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +USE_TRITON_INTERPRET = os.environ.get('TRITON_INTERPRET', '0').lower() in ('1', 'true', 'yes') +DEBUG_TRITON = os.environ.get('DEBUG_TRITON', '0').lower() in ('1', 'true', 'yes') and USE_TRITON_INTERPRET +DEBUG_TRITON_DETAIL = os.environ.get('DEBUG_TRITON_DETAIL', '0').lower() in ('1', 'true', 'yes') and USE_TRITON_INTERPRET +if USE_TRITON_ROCM: # TODO remove this + random.seed(42) +DROPOUT_USE_PYTORCH = False +DROPOUT_DUMP = False + + +# ------------------------------- +# Metadata +# ------------------------------- class MetaData(): - cu_seqlens_q = None - cu_seqlens_k = None - max_seqlens_q = 0 - max_seqlens_k = 0 - bias = None - alibi_slopes = None - causal = False + cu_seqlens_q: Optional[torch.Tensor] = None + cu_seqlens_k: Optional[torch.Tensor] = None + max_seqlens_q: int = 0 + max_seqlens_k: int = 0 + bias: Optional[torch.Tensor] = None + alibi_slopes: Optional[torch.Tensor] = None + causal: bool = False num_contexts = 0 - varlen = False - layout = None - cache_seqlens = None + varlen: bool = False + layout: Optional[Literal["bshd", "bhsd", "thd"]] = None + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None cache_batch_idx = None - new_kv = False - seqlen_new = None - k_new = None - v_new = None - dropout_p, return_scores= 0.0, False + packing: Optional[bool] = None + return_scores: bool = False + dropout_p: float = 0.0 + philox_seed: Optional[int] = None + philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. - use_exp2 = False + use_exp2: bool = False + rotary_sin: Optional[torch.Tensor] = None + rotary_cos: Optional[torch.Tensor] = None + rotary_interleaved: bool = False + rotary_conjunction: bool = False def __repr__(self) -> str: @@ -44,10 +70,6 @@ def __repr__(self) -> str: f" layout={self.layout},\n" f" cache_seqlens={self.cache_seqlens},\n" f" cache_batch_idx={self.cache_batch_idx},\n" - f" new_kv={self.new_kv},\n" - f" seqlen_new={self.seqlen_new},\n" - f" k_new={self.k_new},\n" - f" v_new={self.v_new},\n" f" dropout_p={self.dropout_p},\n" f" return_scores={self.return_scores}\n" f")") @@ -55,18 +77,17 @@ def __repr__(self) -> str: def __init__(self, sm_scale=1.0): self.sm_scale = sm_scale - def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k): self.varlen = True self.layout = 'thd' self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k + self.max_seqlens_q = max_seqlen_q + self.max_seqlens_k = max_seqlen_k + # Without "varlen", there should still be one sequence. 
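+        # cu_seqlens_q / cu_seqlens_k are cumulative lengths with a leading 0, so they contain
+        # batch + 1 entries; even a single sequence yields two entries, hence the asserts below.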
assert len(cu_seqlens_q) >= 2 assert len(cu_seqlens_q) == len(cu_seqlens_k) - self.num_contexts = len(cu_seqlens_q) - 1 - for i in range(0, self.num_contexts): - self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) - self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): assert bias.is_cuda @@ -82,17 +103,24 @@ def need_alibi(self, alibi_slopes, batch, nheads): assert alibi_slopes.shape[1] == nheads self.alibi_slopes = alibi_slopes - def need_causal(self): - self.causal = True + def need_causal(self, causal): + self.causal = causal + + def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False): + self.rotary_sin = sin + self.rotary_cos = cos + self.rotary_interleaved = rotary_interleaved + self.rotary_conjunction = rotary_conjunction - def need_dropout(self, dropout_p, return_scores): + def need_dropout(self, dropout_p, return_softmax = True): self.dropout_p = dropout_p - self.return_scores = return_scores + self.return_softmax = return_softmax + self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49 def check_args(self, q, k, v, o): assert q.dim() == k.dim() and q.dim() == v.dim() - batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) + batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) if self.varlen: assert q.dim() == 3 assert self.cu_seqlens_q is not None @@ -100,8 +128,6 @@ def check_args(self, q, k, v, o): assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) # TODO: Remove once bias is supported with varlen assert self.bias is None - # TODO:Remove once dropout is supported with varlen - assert self.dropout_p == 0.0 # assert not self.return_scores else: assert q.dim() == 4 @@ -111,131 +137,545 @@ def check_args(self, q, k, v, o): assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] # TODO: Change assert if we support qkl f8 and v f16 assert q.dtype == k.dtype and q.dtype == v.dtype - assert head_size <= 256 assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 assert self.layout is not None assert self.layout == 'thd' or not self.varlen -def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cuda", DEBUG_INPUT=False): - torch.manual_seed(20) +# ------------------------------- +# Input Helper +# ------------------------------- +def random_seqlens_composition(SEQ_LEN, BATCH): + # generate a random composition of N into Z positive parts. 
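+    # Sketch of the idea (hypothetical values; the actual split depends on torch's RNG state):
+    #   random_seqlens_composition(10, 3)  ->  tensor([3, 5, 2], dtype=torch.int32)
+    # BATCH - 1 distinct breakpoints are drawn from [1, SEQ_LEN - 1]; consecutive differences
+    # give BATCH positive parts that sum to SEQ_LEN.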
+ idx = torch.randperm(SEQ_LEN - 1)[: BATCH - 1] + 1 + idx, _ = torch.sort(idx) + breakpoints = torch.cat([ + torch.tensor([0], dtype=torch.long), + idx, + torch.tensor([SEQ_LEN], dtype=torch.long), + ]) + seqlens = (breakpoints[1:] - breakpoints[:-1]).to(torch.int32) + return seqlens + +def generate_varlen_tensor( + total_seqlen: int, + num_heads: int, + head_size: int, + batch_size: Optional[int] = None, + equal_seqlens: bool = False, + device: str = "cuda", + dtype: torch.dtype = torch.float32, + DEBUG_INPUT: bool = False +): + if DEBUG: + print("total_seqlen", total_seqlen) + print("num_heads", num_heads) + print("head_size", head_size) + + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # get valid batch_size + if batch_size is None: + valid_batch_sizes = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen] + batch_size = random.choice(valid_batch_sizes) + + # get seqlens + if equal_seqlens: + seqlens = torch.full( + (batch_size,), + total_seqlen // batch_size, + dtype=torch.int32, + device=device + ) + seqlens[-1] += total_seqlen % batch_size + else: + seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) - # Initialize q, k, v - if layout == 'bhsd': - q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) - k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) - elif layout == 'bshd': - q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) - k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) + # create cumulative sequence lengths + cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)]).to(torch.int32).to(device=device) + max_seqlen = torch.max(seqlens).to(torch.int32).item() + + # create varlen tensor + if DEBUG_INPUT: + x = torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device) + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i+1].item() + length = end - start + + x[start:end, :, :] = ( + torch.arange(length, dtype=dtype, device=device) + .view(length, 1, 1) + .expand(length, num_heads, head_size) + ) else: - assert False, f'Got unsupported tensor layout: {layout}' + x = torch.randn((total_seqlen, num_heads, head_size), dtype=dtype, device=device) + if is_fp8_dtype: + # cast to fp8 + x, descale_x = cast_to_fp8(x, og_fp8_dtype, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + x.requires_grad_() + return x, cu_seqlens, max_seqlen, descale_x + else: + x.requires_grad_() + return x, cu_seqlens, max_seqlen + +def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False): + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # gen tensor + tensor_shape = (BATCH, SEQ_LEN, NUM_HEADS, D_HEAD) if DEBUG_INPUT: - if layout == "bhsd": - q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, 1, N_CTX_Q, 1).expand(*q_tensor_shape).contiguous().requires_grad_() - k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_() - v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_() - elif layout == "bshd": - q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, N_CTX_Q, 1, 1).expand(*q_tensor_shape).contiguous().requires_grad_() - k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_() - v = 
torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_() + x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, SEQ_LEN, 1, 1).expand(*tensor_shape).contiguous() + else: + x = torch.randn(tensor_shape, dtype=dtype, device=device) + + if is_fp8_dtype: + # cast to fp8 + x, descale_x = cast_to_fp8(x, og_fp8_dtype, "bshd") + x.requires_grad_() + return x, descale_x else: - q = torch.randn(q_tensor_shape, dtype=dtype, device=device, requires_grad=True) - k = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True) - v = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True) + x.requires_grad_() + return x + +def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False): + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + # gen tensor + tensor_shape = (BATCH, NUM_HEADS, SEQ_LEN, D_HEAD) if DEBUG_INPUT: - sm_scale = 1 + x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, 1, SEQ_LEN, 1).expand(*tensor_shape).contiguous() else: - sm_scale = D_HEAD**-0.5 - input_metadata = MetaData(sm_scale=sm_scale) - input_metadata.max_seqlens_q = N_CTX_Q - input_metadata.max_seqlens_k = N_CTX_K - input_metadata.layout = layout - return q, k, v, input_metadata - + x = torch.randn(tensor_shape, dtype=dtype, device=device) + -def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False): + if is_fp8_dtype: + # cast to fp8 + x, descale_x = cast_to_fp8(x, og_fp8_dtype, "bhsd") # FIXME: I don't the casting fn supports this atm + x.requires_grad_() + return x, descale_x + else: + x.requires_grad_() + return x + +def input_helper( + BATCH: int, + HQ: int, + HK: int, + N_CTX_Q: int, + N_CTX_K: int, + D_HEAD: int, + CAUSAL: bool, + DROPOUT_P: float, + dtype: torch.dtype, + layout: Literal["bshd", "bhsd", "thd"], + packing: Optional[Literal["kv", "qkv"]] = None, + device: Literal["cpu", "cuda"] = "cuda", + DEBUG_INPUT: bool = False, +): torch.manual_seed(20) - - # Random or equal sequence lengths based on 'equal_seqlens' flag - if not equal_seqlens: - max_seqlens_q = N_CTX_Q // Z - max_seqlens_k = N_CTX_K // Z - seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32) - seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32) + is_fp8_dtype = is_dtype_fp8(dtype) + + if layout == "thd": + # set params + TOTAL_SEQLENS_Q = BATCH * N_CTX_Q + TOTAL_SEQLENS_K = BATCH * N_CTX_K + equal_seqlens=False + + # gen tensors + # TODO: the gen functions should maybe have different gen modes like random, ones, increasing seqlen + if is_fp8_dtype: + q, cu_seqlens_q, max_seqlen_q, descale_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + k, cu_seqlens_k, max_seqlen_k, descale_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + v, _, _ , descale_v = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + do, _, _ , descale_do = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + else: + q, cu_seqlens_q, max_seqlen_q = 
generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + + # setup metadata + if DEBUG_INPUT: + sm_scale = 1 + else: + sm_scale = D_HEAD**-0.5 + metadata = MetaData(sm_scale=sm_scale) + metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) + metadata.need_causal(CAUSAL) + metadata.need_dropout(DROPOUT_P) + elif layout == 'bshd' or layout == "bhsd": + # gen tensors + if layout == "bshd": + if is_fp8_dtype: + q, descale_q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + k, descale_k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + v, descale_v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do, descale_do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) + else: + q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + elif layout == "bhsd": + if is_fp8_dtype: + q, descale_q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + k, descale_k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + v, descale_v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do, descale_do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) + else: + q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + + # setup metadata + if DEBUG_INPUT: + sm_scale = 1 + else: + sm_scale = D_HEAD**-0.5 + metadata = MetaData(sm_scale=sm_scale) + metadata.max_seqlens_q = N_CTX_Q + metadata.max_seqlens_k = N_CTX_K + metadata.layout = layout + metadata.need_causal(CAUSAL) + metadata.need_dropout(DROPOUT_P) else: - seqlens_q = torch.full((Z,), N_CTX_Q // Z, dtype=torch.int32) - seqlens_k = torch.full((Z,), N_CTX_K // Z, dtype=torch.int32) + raise ValueError(f"Unknown layout: {layout}") - # Calculate cumulative sequence lengths - cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0)]) - cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0)]) - cu_seqlens_q = cu_seqlens_q.to(device=device).to(torch.int32) - cu_seqlens_k = cu_seqlens_k.to(device=device).to(torch.int32) + # deal with packing + if 
packing is None: + if is_fp8_dtype: + return (q, descale_q), (k, descale_k), (v, descale_v), (do, descale_do), metadata + else: + return q, k, v, do, metadata + elif packing == "kv": + # pack k and v + if layout in ["bhsd", "thd"]: + kv = torch.stack([k, v], dim=1) + elif layout == "bshd": + kv = torch.stack([k, v], dim=2) + else: + raise ValueError(f"Unknown layout: {layout}") - # Total lengths - total_q = cu_seqlens_q[-1].item() - total_k = cu_seqlens_k[-1].item() + if is_fp8_dtype: + raise ValueError("FP8 not supported kv packing yet") + else: + return q, kv, do, metadata + elif packing == "qkv": + # qkv packing - requires same sequence length for q and k + assert N_CTX_Q == N_CTX_K, "For QKV packing, Q and K must have same sequence length" + assert HQ == HK, "For QKV packing, Q and K must have same number of heads" + + # pack q, k, and v + if layout in ["bhsd", "thd"]: + qkv = torch.stack([q, k, v], dim=1) + elif layout == "bshd": + qkv = torch.stack([q, k, v], dim=2) + else: + raise ValueError(f"Unknown layout: {layout}") - if DEBUG_INPUT: - # Initialize q, k, v with deterministic values - q = torch.arange(total_q, dtype=dtype, device=device).view(total_q, 1, 1) - q = q.expand(total_q, HQ, D_HEAD).contiguous().requires_grad_() - k = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1) - k = k.expand(total_k, HK, D_HEAD).contiguous().requires_grad_() - v = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1) - v = v.expand(total_k, HK, D_HEAD).contiguous().requires_grad_() - sm_scale = 1 + if is_fp8_dtype: + raise ValueError("FP8 not supported qkv packing yet") + else: + return qkv, do, metadata else: - # Initialize q, k, v with random values - q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device=device).requires_grad_() - k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_() - v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_() - sm_scale = D_HEAD ** -0.5 - - input_metadata = MetaData(sm_scale=sm_scale) - input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) - return q, k, v, input_metadata - - -def get_shape_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None): + assert False, f"Unsupported packing mode: {packing}" + +# ------------------------------- +# Alibi +# ------------------------------- +@triton.jit +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + +# ------------------------------- +# FP8 +# ------------------------------- +def is_dtype_fp8(dtype): + if dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}: + if arch_supports_fp8(): + return True + else: + raise RuntimeError("This device doesnot support fp8") + else: + return False + +def is_fp8(x): + return is_dtype_fp8(x.dtype) + +@triton.jit +def compute_fp8_scaling_factors(x, fp8_max: tl.constexpr): + # compute fp8 scaling and descaling factor for a block + x_amax = tl.max(tl.abs(x)) # NOTE: abs deals with negative values + x_amax = tl.where(x_amax <= 1e-9, 1e-9, x_amax) + scale_x = fp8_max / x_amax + descale_x = x_amax / fp8_max + return scale_x, descale_x + +@triton.jit +def _cast_varlen_to_fp8_kernel_2d( + X, X_fp8, Descale, + cu_seqlens, H, MAX_SEQLEN, + stride_batch, stride_seq, stride_head, stride_dim, + stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim, + stride_desc_batch, stride_desc_head, + FP8_CLAMP_VAL, + FP8_MAX, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + IS_VARLEN: tl.constexpr + ): + # Process one (batch, head) pair per kernel + b_id = tl.program_id(0) + h_id = tl.program_id(1) + + # Get sequence bounds for this batch + if IS_VARLEN: + seq_start = tl.load(cu_seqlens + b_id) + seq_end = tl.load(cu_seqlens + b_id + 1) + seqlen = seq_end - seq_start + else: + seq_start = 0 + seqlen = MAX_SEQLEN + + # initialize max value tracker + x_max_val = 0.0 + + # STEP 1: Find max absolute value across the entire sequence + num_of_blocks = tl.cdiv(seqlen, BLOCK_SIZE) + for blk_idx in range(0, num_of_blocks): + # print("blk_idx:", blk_idx) + # offsets + offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_dim = tl.arange(0, HEAD_DIM) + + # Create mask for valid elements + mask_seq = offs_seq[:, None] < seqlen + if ACTUAL_HEAD_DIM != HEAD_DIM: + mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM + mask_seq = mask_seq & mask_dim + + # Load block + adj_x = b_id * stride_batch + h_id * stride_head + seq_start * stride_seq + offs_seq[:, None] * stride_seq + offs_dim[None, :] * stride_dim + x_block = tl.load(X + adj_x, mask=mask_seq, other=0.0) + # print("x_block:", x_block) + + # Find max absolute value in this block + block_max = tl.max(tl.abs(x_block)) + # print("block_max:", block_max) + + # Update overall max + x_max_val = tl.maximum(x_max_val, block_max) + # print("x_max_val:", x_max_val) + + # clamp to avoid division by zero issues + x_max_val = tl.maximum(x_max_val, FP8_CLAMP_VAL) + + # compute scale and descale factors for the entire sequence + scale = FP8_MAX / x_max_val + descale = x_max_val / FP8_MAX + + # store descale factor for this (batch, head) pair + desc_ptr = Descale + b_id * stride_desc_batch + h_id# * stride_desc_head + tl.store(desc_ptr, descale) + + # STEP 2: Apply scaling to the entire sequence and convert to FP8 + for blk_idx in range(0, num_of_blocks): + # offsets + offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_dim = tl.arange(0, HEAD_DIM) + + # Create mask for valid elements + mask_seq = offs_seq[:, None] < seqlen + if ACTUAL_HEAD_DIM != HEAD_DIM: + mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM + mask_seq = mask_seq & 
mask_dim + + # Load block - Using the fixed addressing + addr = b_id * stride_batch + h_id * stride_head + seq_start * stride_seq + offs_seq[:, None] * stride_seq + offs_dim[None, :] * stride_dim + x_block = tl.load(X + addr, mask=mask_seq, other=0.0) + + # Apply scale and convert to FP8 + x_fp8_block = (x_block * scale).to(X_fp8.type.element_ty) + + # Store results + addr_out = b_id * stride_out_batch + h_id * stride_out_head + seq_start * stride_out_seq + offs_seq[:, None] * stride_out_seq + offs_dim[None, :] * stride_out_dim + tl.store(X_fp8 + addr_out, x_fp8_block, mask=mask_seq) + +def cast_to_fp8( + x: torch.Tensor, + fp8_dtype: torch.dtype, + layout: Literal["bshd", "thd"], + clamp_val: float = 1e-9, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None +) -> tuple[torch.Tensor, torch.Tensor]: + if False: + print() + print("cast_to_fp8") + print("x:", x, x.shape) + print("fp8_dtype:", fp8_dtype) + print("cu_seqlens:", cu_seqlens) + print("max_seqlen:", max_seqlen) + print("clamp_val:", clamp_val) + + # check types are valid + assert x.dtype in {torch.float16, torch.float32, torch.float64, torch.bfloat16} and is_dtype_fp8(fp8_dtype), f"Cannot cast {x.dtype} to {fp8_dtype}" + + # extract dimensions + batch, max_seqlen_final, num_heads, head_dim = get_shape_from_layout(x, layout, cu_seqlens, max_seqlen) + is_varlen = layout == "thd" + fp8_max = torch.finfo(fp8_dtype).max + if False: + print("batch:", batch) + print("max_seqlen_final:", max_seqlen_final) + print("num_heads:", num_heads) + print("head_dim:", head_dim) + + # get closest power of 2 for head_dim + padded_head_dim = 1 << (head_dim - 1).bit_length() + padded_head_dim = max(padded_head_dim, 32) + + # kernel params + x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) + descale_factors = torch.zeros((batch, num_heads), device=x.device, dtype=torch.float32) + BLOCK_SIZE = 128 + + # calculate strides + stride_batch, stride_head, stride_seq, stride_dim = get_stride_from_layout(x, layout) + stride_out_batch, stride_out_head, stride_out_seq, stride_out_dim = get_stride_from_layout(x_fp8, layout) + stride_desc_batch, stride_desc_head = descale_factors.stride() + + if False: + print("stride_batch", stride_batch) + print("stride_head", stride_head) + print("stride_seq", stride_seq) + print("stride_dim", stride_dim) + print("stride_out_batch", stride_out_batch) + print("stride_out_head", stride_out_head) + print("stride_out_seq", stride_out_seq) + print("stride_out_dim", stride_out_dim) + print("stride_desc_batch", stride_desc_batch) + print("stride_desc_head", stride_desc_head) + + grid = (batch, num_heads) + _cast_varlen_to_fp8_kernel_2d[grid]( + x, x_fp8, descale_factors, + cu_seqlens, num_heads, max_seqlen_final, + stride_batch, stride_seq, stride_head, stride_dim, + stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim, + stride_desc_batch, stride_desc_head, + clamp_val, fp8_max, + BLOCK_SIZE=BLOCK_SIZE, + HEAD_DIM=padded_head_dim, + ACTUAL_HEAD_DIM=head_dim, + IS_VARLEN=is_varlen + ) + + if False: + print("x_fp8:", x_fp8, x_fp8.shape) + print("descale_factors:", descale_factors, descale_factors.shape) + return x_fp8, descale_factors + +# ------------------------------- +# Misc +# ------------------------------- +def get_shape_from_layout( + x: torch.Tensor, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +) -> tuple[int, int, int, int]: if layout == 'bhsd': - batch_q, nheads_q, max_seqlen_q, head_size_q = q.shape - batch_k, 
nheads_k, max_seqlen_k, head_size_k = k.shape + batch, num_heads, max_seqlen_final, head_dim = x.shape elif layout == 'bshd': - batch_q, max_seqlen_q, nheads_q, head_size_q = q.shape - batch_k, max_seqlen_k, nheads_k, head_size_k = k.shape + batch, max_seqlen_final, num_heads, head_dim = x.shape elif layout == 'thd': - batch_q, max_seqlen_q, nheads_q, head_size_q = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] - batch_k, max_seqlen_k, nheads_k, head_size_k = len(cu_seqlens_k) - 1, max_seqlen_k, k.shape[1], k.shape[2] + total_seqlen, num_heads, head_dim = x.shape + if cu_seqlens is None: + raise ValueError("cu_seqlens must be provided for varlen (thd) layout") + if max_seqlen is None: + raise ValueError("max_seqlen must be provided for varlen (thd) layout") + + batch, max_seqlen_final, num_heads, head_dim = len(cu_seqlens) - 1, max_seqlen, num_heads, head_dim else: assert False, "Got unsupported layout." + + return batch, max_seqlen_final, num_heads, head_dim + + +def get_shapes_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None): + batch_q, seqlen_q, nheads_q, head_size_q = get_shape_from_layout(q, layout, cu_seqlens_q, max_seqlen_q) + batch_k, seqlen_k, nheads_k, head_size_k = get_shape_from_layout(k, layout, cu_seqlens_k, max_seqlen_k) # assert assert batch_q == batch_k assert head_size_q == head_size_k - return batch_q, nheads_q, nheads_k, head_size_q, max_seqlen_q, max_seqlen_k + return batch_q, nheads_q, nheads_k, head_size_q, seqlen_q, seqlen_k -def get_strides_from_layout(q, k, v, o, layout): +def get_stride_from_layout(x: torch.Tensor, layout:Literal["bshd", "bhsd", "thd"]): if layout == 'thd': - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + strides = (0, x.stride(1), x.stride(0), x.stride(2)) elif layout == 'bhsd': - q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) - k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) - v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) - o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + strides = (x.stride(0), x.stride(1), x.stride(2), x.stride(3)) elif layout == 'bshd': - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + strides = (x.stride(0), x.stride(2), x.stride(1), x.stride(3)) else: assert False, 'Got unsupported layout.' 
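# Hedged usage sketch (names taken from this file, shapes assumed): for the
# packed varlen "thd" layout the tensor is (total_tokens, nheads, headdim), so
# the batch stride is reported as 0 and per-sequence offsets come from
# cu_seqlens instead of a batch dimension. Roughly:
#   >>> x = torch.randn(total_tokens, nheads, headdim, device="cuda")
#   >>> get_stride_from_layout(x, "thd")
#   (0, x.stride(1), x.stride(0), x.stride(2))
#   >>> get_shape_from_layout(x, "thd", cu_seqlens, max_seqlen)
#   (len(cu_seqlens) - 1, max_seqlen, nheads, headdim)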
+ return strides + +def get_shape_and_strides_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"], cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None): + return get_shape_from_layout(x, layout, cu_seqlens, max_seqlen), get_stride_from_layout(x, layout) + +def get_strides_from_layout(q, k, v, o, layout): + q_strides = get_stride_from_layout(q, layout) + k_strides = get_stride_from_layout(k, layout) + v_strides = get_stride_from_layout(v, layout) + o_strides = get_stride_from_layout(o, layout) return q_strides, k_strides, v_strides, o_strides def get_padded_headsize(size): @@ -246,29 +686,90 @@ def get_padded_headsize(size): padded_d_model = max(padded_d_model, 16) return padded_d_model - -def _strides(x: torch.Tensor, *stride_names: str): - if x is None: - return {f"stride_{s}": 0 for i, s in enumerate(stride_names)} - - assert x.ndim == len(stride_names) - return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} - -def get_input_shapes(): - cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128) - for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)] - return cases - +def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) + +# ------------------------------- +# Dropouts +# ------------------------------- +def create_dropout_mask(dropout_p, shape, seed): + device = "cuda" + rand_vals = torch.rand(shape, generator=torch.Generator(device=device).manual_seed(seed), device=device, dtype=torch.float32) + return rand_vals > dropout_p + +def create_dropout_mask_varlen(dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed): + device = "cuda" + qlens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) + klens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) + max_qlen = qlens.max() + max_klen = klens.max() + dropout_mask = torch.zeros((batch, nheads_q, max_qlen, max_klen), device=device) + for b in range(batch): + qlen = qlens[b] + klen = klens[b] + rand_vals = torch.rand((nheads_q, qlen, klen), generator=torch.Generator(device=device).manual_seed(philox_seed), device=device, dtype=torch.float32) + submask = rand_vals > dropout_p + dropout_mask[b, :, :qlen, :klen] = submask + + return dropout_mask + +def write_dropout_mask(x, tensor_name = "tensor"): + batch, head, seqlen_m, seqlen_n = x.shape + x = x.tolist() + + with open(f'{tensor_name}.csv', 'w') as f: + writer = csv.writer(f) + for b in range(batch): + for h in range(head): + dropout_mask = x[b][h] + if True: + BLOCK_M = 64 + BLOCK_N = 64 + + # Calculate number of blocks in each dimension + m_blocks = math.ceil(seqlen_m / BLOCK_M) + n_blocks = math.ceil(seqlen_n / BLOCK_N) + + # Process each block + for m_block in range(m_blocks): + # Calculate row range for current block + row_start = m_block * BLOCK_M + row_end = min(row_start + BLOCK_M, seqlen_m) + + for n_block in range(n_blocks): + # Calculate column range for current block + col_start = n_block * BLOCK_N + col_end = min(col_start + BLOCK_N, seqlen_n) + + # Extract and write the current block + for row_idx in range(row_start, row_end): + row_data = dropout_mask[row_idx][col_start:col_end] + writer.writerow(row_data) + else: + 
writer.writerows(dropout_mask) + +# ------------------------------- +# Runtime info +# ------------------------------- +@functools.cache def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" +@functools.cache +def get_arch(): + return triton.runtime.driver.active.get_current_target().arch +@functools.cache def is_cdna(): - return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942', - 'gfx90a', 'gfx908') - + return is_hip() and get_arch() in ('gfx908', 'gfx90a', 'gfx940', 'gfx941', 'gfx942', 'gfx950') +@functools.cache def is_rdna(): - return is_hip() and triton.runtime.driver.active.get_current_target().arch in ("gfx1030", "gfx1100", "gfx1101", - "gfx1102", "gfx1200", "gfx1201") + return is_hip() and get_arch() in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201") +@functools.cache +def arch_supports_fp8(): + return is_hip() and get_arch() in ('gfx942') diff --git a/flash_attn/fused_softmax.py b/flash_attn/fused_softmax.py deleted file mode 100644 index 382f94f092c..00000000000 --- a/flash_attn/fused_softmax.py +++ /dev/null @@ -1,201 +0,0 @@ -# [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py -# for benchmarking. -# We added support for seqlen=2k and seqlen=4k - -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch -from apex._autocast_utils import _cast_if_autocast_enabled -from apex.transformer.enums import AttnMaskType -from fused_softmax_lib import ( - scaled_masked_softmax_backward, - scaled_masked_softmax_forward, - scaled_masked_softmax_get_batch_per_block, - scaled_upper_triang_masked_softmax_backward, - scaled_upper_triang_masked_softmax_forward, -) - - -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. 
- """ - - @staticmethod - def forward(ctx, inputs, scale): - scale_t = torch.tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_upper_triang_masked_softmax_backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None - - -def scaled_upper_triang_masked_softmax(inputs, _, scale): - b, np, sq, sk = inputs.size() - assert sq == sk, "causal mask is only for self attention" - # Reshaping input to 3D tensor (attn_batches, sq, sk) - inputs = inputs.view(-1, sq, sk) - args = _cast_if_autocast_enabled(inputs, scale) - with torch.cuda.amp.autocast(enabled=False): - probs = ScaledUpperTriangMaskedSoftmax.apply(*args) - return probs.view(b, np, sq, sk) - - -# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`. -# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context. -# So I needed to manually write two `torch.autograd.Function` inheritances. -# Fused operation which performs following three operations in sequence -# 1. Scale the tensor. -# 2. Apply the mask. -# 3. Perform softmax. -class ScaledMaskedSoftmax(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs, mask, scale): - scale_t = torch.tensor([scale]) - softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -def scaled_masked_softmax(inputs, mask, scale): - # input is 4D tensor (b, np, sq, sk) - args = _cast_if_autocast_enabled(inputs, mask, scale) - with torch.cuda.amp.autocast(enabled=False): - return ScaledMaskedSoftmax.apply(*args) - - -class FusedScaleMaskSoftmax(torch.nn.Module): - """ - fused operation: scaling + mask + softmax - - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - input_in_bf16: flag to indicate if input in bf16 data format. - attn_mask_type: attention mask type (pad or causal) - scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. 
- """ - - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ): - super().__init__() - self.input_in_fp16 = input_in_fp16 - self.input_in_bf16 = input_in_bf16 - if self.input_in_fp16 and self.input_in_bf16: - raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.") - self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - - if not (self.scale is None or softmax_in_fp32): - raise RuntimeError("softmax should be in fp32 when scaled") - - if self.scaled_masked_softmax_fusion: - if self.attn_mask_type == AttnMaskType.causal: - self.fused_softmax_func = scaled_upper_triang_masked_softmax - elif self.attn_mask_type == AttnMaskType.padding: - self.fused_softmax_func = scaled_masked_softmax - else: - raise ValueError("Invalid attn_mask_type.") - - def forward(self, input, mask): - # [b, np, sq, sk] - assert input.dim() == 4 - - if self.is_kernel_available(mask, *input.size()): - return self.forward_fused_softmax(input, mask) - else: - return self.forward_torch_softmax(input, mask) - - def is_kernel_available(self, mask, b, np, sq, sk): - attn_batches = b * np - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and ( - self.attn_mask_type == AttnMaskType.causal - or (self.attn_mask_type == AttnMaskType.padding and mask is not None) - ) - and 16 < sk <= 8192 # sk must be 16 ~ 8192 - and sq % 4 == 0 # sq must be divisor of 4 - and sk % 4 == 0 # sk must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): - if 0 <= sk <= 8192: - batch_per_block = self.get_batch_per_block(sq, sk, b, np) - - if self.attn_mask_type == AttnMaskType.causal: - if attn_batches % batch_per_block == 0: - return True - else: - if sq % batch_per_block == 0: - return True - return False - - def forward_fused_softmax(self, input, mask): - # input.shape = [b, np, sq, sk] - scale = self.scale if self.scale is not None else 1.0 - return self.fused_softmax_func(input, mask, scale) - - def forward_torch_softmax(self, input, mask): - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() - - return probs - - @staticmethod - def get_batch_per_block(sq, sk, b, np): - return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np) diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index 6d021f83910..66baae9ec13 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -1,9 +1,12 @@ # Copyright (c) 2023, Tri Dao. 
import math +from functools import partial from typing import Optional, Tuple, Union import torch +from torch import Tensor + from einops import rearrange, repeat from flash_attn.ops.triton.rotary import apply_rotary @@ -41,8 +44,8 @@ def forward( sin, interleaved=False, inplace=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, + seqlen_offsets: Union[int, Tensor] = 0, + cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None, ): out = apply_rotary( @@ -73,10 +76,6 @@ def backward(ctx, do): cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors else: cos, sin, cu_seqlens = ctx.saved_tensors - # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with - # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. - if not ctx.interleaved and not ctx.inplace: - do = do.clone() dx = apply_rotary( do, cos, @@ -97,8 +96,8 @@ def apply_rotary_emb( sin, interleaved=False, inplace=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, + seqlen_offsets: Union[int, Tensor] = 0, + cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None, ): """ @@ -128,6 +127,70 @@ def apply_rotary_emb( apply_rotary_emb_func = apply_rotary_emb +def _apply_rotary_emb_qkv( + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + inplace=False, + conjugate=False, + seqlen_offsets: Union[int, Tensor] = 0, + num_heads_q: Optional[int] = None, +): + apply_rotary_fn = partial( + apply_rotary, + interleaved=interleaved, + inplace=inplace, + conjugate=conjugate, + seqlen_offsets=seqlen_offsets + ) + if cos_k is None and sin_k is None and qkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + if qkv.dim() == 5: + batch, seqlen, three, nheads, headdim = qkv.shape + assert three == 3 + # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") + qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) + qk = apply_rotary_fn(qk, cos, sin) + else: + assert qkv.dim() == 4 + assert num_heads_q is not None + num_heads_k = (qkv.shape[2] - num_heads_q) // 2 + assert qkv.shape[2] == num_heads_q + 2 * num_heads_k + qk = qkv[:, :, :num_heads_q + num_heads_k] + qk = apply_rotary_fn(qk, cos, sin) + if not inplace: + if qkv.dim() == 5: + qkv = torch.cat([rearrange(qk, "b s (t h) d -> b s t h d", t=2), qkv[:, :, 2:]], dim=2) + else: + qkv = torch.cat([qk, qkv[:, :, num_heads_q + num_heads_k :]], dim=2) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + if qkv.dim() == 5: + batch, seqlen, three, nheads, headdim = qkv.shape + assert three == 3 + q, k = qkv[:, :, 0], qkv[:, :, 1] + else: + assert qkv.dim() == 4 + assert num_heads_q is not None + num_heads_k = (qkv.shape[2] - num_heads_q) // 2 + assert qkv.shape[2] == num_heads_q + 2 * num_heads_k + q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k] + q = apply_rotary_fn(q, cos, sin) + k = apply_rotary_fn(k, cos_k, sin_k) + if not inplace: + if qkv.dim() == 5: + qkv = torch.stack([q, k, qkv[:, :, 2]], dim=2) + else: + qkv = torch.cat([q, k, qkv[:, :, num_heads_q + num_heads_k:]], dim=2) + return qkv + + class ApplyRotaryEmbQKV_(torch.autograd.Function): @staticmethod def forward( @@ -139,40 +202,13 @@ def forward( sin_k=None, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0, - num_heads_q: Union[int] = None, + 
num_heads_q: Optional[int] = None, ): - if cos_k is None and sin_k is None and qkv.is_contiguous(): - # Call 1 kernel instead of 2 kernels - # We need qkv to be contiguous so that when we reshape to combine (3, nheads) - # dimensions, we get the same tensor - if qkv.dim() == 5: - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") - qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) - else: - assert qkv.dim() == 4 - assert num_heads_q is not None - num_heads_k = (qkv.shape[2] - num_heads_q) // 2 - assert qkv.shape[2] == num_heads_q + 2 * num_heads_k - qk = qkv[:, :, :num_heads_q + num_heads_k] - apply_rotary( - qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True - ) - else: - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - if qkv.dim() == 5: - q, k = qkv[:, :, 0], qkv[:, :, 1] - else: - assert qkv.dim() == 4 - assert num_heads_q is not None - num_heads_k = (qkv.shape[2] - num_heads_q) // 2 - assert qkv.shape[2] == num_heads_q + 2 * num_heads_k - q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k] - apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True) - apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True) - ctx.save_for_backward(cos, sin, cos_k, sin_k) + # apply_rotary_emb_qkv_inplace( + qkv = _apply_rotary_emb_qkv( + qkv, cos, sin, cos_k, sin_k, interleaved=interleaved, inplace=True, + seqlen_offsets=seqlen_offsets, num_heads_q=num_heads_q, + ) if isinstance(seqlen_offsets, int): ctx.save_for_backward(cos, sin, cos_k, sin_k) ctx.seqlen_offsets = seqlen_offsets @@ -190,57 +226,10 @@ def backward(ctx, dqkv): cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors else: cos, sin, cos_k, sin_k = ctx.saved_tensors - if cos_k is None and sin_k is None and dqkv.is_contiguous(): - # Call 1 kernel instead of 2 kernels - # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) - # dimensions, we get the same tensor - if dqkv.dim() == 5: - dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d") - else: - assert dqkv.dim() == 4 - assert ctx.num_heads_q is not None - num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2 - assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k - dqk = dqkv[:, :, : ctx.num_heads_q + num_heads_k] - apply_rotary( - dqk, - cos, - sin, - seqlen_offsets=seqlen_offsets, - interleaved=ctx.interleaved, - inplace=True, - conjugate=True, - ) - else: - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - if dqkv.dim() == 5: - dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] - else: - assert dqkv.dim() == 4 - assert ctx.num_heads_q is not None - num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2 - assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k - dq = dqkv[:, :, : ctx.num_heads_q] - dk = dqkv[:, :, ctx.num_heads_q : ctx.num_heads_q + num_heads_k] - apply_rotary( - dq, - cos, - sin, - seqlen_offsets, - interleaved=ctx.interleaved, - inplace=True, - conjugate=True, - ) - apply_rotary( - dk, - cos_k, - sin_k, - seqlen_offsets, - interleaved=ctx.interleaved, - inplace=True, - conjugate=True, - ) + dqkv = _apply_rotary_emb_qkv( + dqkv, cos, sin, cos_k, sin_k, interleaved=ctx.interleaved, inplace=True, + seqlen_offsets=seqlen_offsets, num_heads_q=ctx.num_heads_q, conjugate=True, + ) return dqkv, None, None, None, None, None, None, None @@ -362,27 +351,15 @@ def __init__( base=10000.0, 
interleaved=False, scale_base=None, - pos_idx_in_fp32=True, device=None, ): """ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). - pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, - otherwise they might be in lower precision. - This option was added because previously (before 2023-07-02), when we construct - the position indices, we use the dtype of self.inv_freq. In most cases this would - be fp32, but if the model is trained in pure bf16 (not mixed precision), then - self.inv_freq would be bf16, and the position indices are also in bf16. - Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the - embeddings for some positions will coincide. - To maintain compatibility with models previously trained in pure bf16, - we add this option. """ super().__init__() self.dim = dim self.base = float(base) - self.pos_idx_in_fp32 = pos_idx_in_fp32 # Generate and save the inverse frequency buffer (non trainable) inv_freq = self._compute_inv_freq(device) self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -421,19 +398,14 @@ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): self._seq_len_cached = seqlen # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 # And the output of arange can be quite large, so bf16 would lose a lot of precision. - # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. - if self.pos_idx_in_fp32: - t = torch.arange(seqlen, device=device, dtype=torch.float32) - # We want fp32 here as well since inv_freq will be multiplied with t, and the output - # will be large. Having it in bf16 will lose a lot of precision and cause the - # cos & sin output to change significantly. - # We want to recompute self.inv_freq if it was not loaded in fp32 - if self.inv_freq.dtype != torch.float32: - inv_freq = self._compute_inv_freq(device=device) - else: - inv_freq = self.inv_freq + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. 
+ # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) else: - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) inv_freq = self.inv_freq # Don't do einsum, it converts fp32 to fp16 under AMP # freqs = torch.einsum("i,j->ij", t, self.inv_freq) @@ -479,26 +451,16 @@ def forward( elif isinstance(seqlen_offset, int): self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) if kv is None: - if self.scale is None: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - num_heads_q=num_heads_q, - ) - else: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached, - self._sin_cached, - self._cos_k_cached, - self._sin_k_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - num_heads_q=num_heads_q, - ) + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + self._cos_k_cached if self.scale is not None else None, + self._sin_k_cached if self.scale is not None else None, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + num_heads_q=num_heads_q, + ) else: q = qkv q = apply_rotary_emb_func( @@ -509,20 +471,11 @@ def forward( inplace=True, seqlen_offsets=seqlen_offset, ) - if self.scale is None: - kv = apply_rotary_emb_kv_( - kv, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - else: - kv = apply_rotary_emb_kv_( - kv, - self._cos_k_cached, - self._sin_k_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) + kv = apply_rotary_emb_kv_( + kv, + self._cos_cached if self.scale is None else self._cos_k_cached, + self._sin_cached if self.scale is None else self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) return q, kv diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 2c0a4f1b871..b2a7f22d243 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -25,7 +25,7 @@ try: from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear except ImportError: - ColumnParallelLinear, RowParallelLinear = None, None, None + ColumnParallelLinear, RowParallelLinear = None, None try: from flash_attn.layers.rotary import RotaryEmbedding diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py index 0427e957e8e..08119cc99a5 100644 --- a/flash_attn/ops/triton/layer_norm.py +++ b/flash_attn/ops/triton/layer_norm.py @@ -7,14 +7,25 @@ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 
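# Hedged note on the zero_centered_weight option threaded through this file: the
# learnable weight is stored centered at zero and the kernels add 1.0 before using
# it, so y = x_hat * (weight + 1) (+ bias), and RMSNorm initializes its weight with
# zeros when zero_centered_weight=True. A rough usage sketch:
#   >>> norm = RMSNorm(4096, zero_centered_weight=True, device="cuda")
#   >>> y = norm(x)  # the fused kernel applies weight + 1.0 internally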
import math +from typing import Optional, List import torch import torch.nn.functional as F +from torch import Tensor import triton import triton.language as tl from flash_attn.utils.torch import custom_fwd, custom_bwd +from flash_attn.utils.library import triton_op + + +def maybe_contiguous_lastdim(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def maybe_contiguous(x): + return x.contiguous() if x is not None else None def triton_autotune_configs(): @@ -25,11 +36,10 @@ def triton_autotune_configs(): # Default to warp size 32 if not defined by device warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit - warp_count=1 - while warp_count*warp_size <= max_threads_per_block: - configs.append(triton.Config({}, num_warps=warp_count)) - warp_count*=2 - return configs + return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] + if warp_count * warp_size <= max_threads_per_block] + # return [triton.Config({}, num_warps=8)] + def layer_norm_ref( x, @@ -43,6 +53,7 @@ def layer_norm_ref( dropout_p=0.0, rowscale=None, prenorm=False, + zero_centered_weight=False, dropout_mask=None, dropout_mask1=None, upcast=False, @@ -56,6 +67,10 @@ def layer_norm_ref( x1 = x1.float() if x1 is not None else None weight1 = weight1.float() if weight1 is not None else None bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 if x1 is not None: assert rowscale is None, "rowscale is not supported with parallel LayerNorm" if rowscale is not None: @@ -98,6 +113,7 @@ def rms_norm_ref( dropout_p=0.0, rowscale=None, prenorm=False, + zero_centered_weight=False, dropout_mask=None, dropout_mask1=None, upcast=False, @@ -111,6 +127,10 @@ def rms_norm_ref( x1 = x1.float() if x1 is not None else None weight1 = weight1.float() if weight1 is not None else None bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 if x1 is not None: assert rowscale is None, "rowscale is not supported with parallel LayerNorm" if rowscale is not None: @@ -142,13 +162,13 @@ def rms_norm_ref( @triton.autotune( configs=triton_autotune_configs(), - key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS", "HAS_X1", "HAS_W1", "HAS_B1"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) -@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) -@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) -@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) @triton.jit def _layer_norm_fwd_1pass_kernel( X, # pointer to the input @@ -164,6 +184,7 @@ def _layer_norm_fwd_1pass_kernel( ROWSCALE, SEEDS, # Dropout seeds for each row DROPOUT_MASK, + DROPOUT_MASK1, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row @@ -176,6 +197,7 @@ def 
_layer_norm_fwd_1pass_kernel( N, # number of columns in X eps, # epsilon to avoid division by zero dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr, @@ -226,7 +248,7 @@ def _layer_norm_fwd_1pass_kernel( ) x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) + tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) x += x1 if HAS_RESIDUAL: residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) @@ -246,6 +268,8 @@ def _layer_norm_fwd_1pass_kernel( # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 if HAS_BIAS: b = tl.load(B + cols, mask=mask).to(tl.float32) x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd @@ -254,6 +278,8 @@ def _layer_norm_fwd_1pass_kernel( tl.store(Y + cols, y, mask=mask) if HAS_W1: w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 if HAS_B1: b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 @@ -261,25 +287,87 @@ def _layer_norm_fwd_1pass_kernel( def _layer_norm_fwd( - x, - weight, - bias, - eps, - residual=None, - x1=None, - weight1=None, - bias1=None, - dropout_p=0.0, - rowscale=None, - out_dtype=None, - residual_dtype=None, - is_rms_norm=False, - return_dropout_mask=False, - out=None, - residual_out=None -): + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + residual_dtype: Optional[torch.dtype] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + out: Optional[Tensor] = None, + residual_out: Optional[Tensor] = None +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library + # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None + # so that _layer_norm_fwd_impl doesn't have to return them. 
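    # Descriptive note on the allocation rule just below: residual_out is
    # materialized only when it can actually differ from x (a residual is supplied,
    # the residual dtype differs from x.dtype, dropout_p > 0, a rowscale is given,
    # or a parallel x1 stream exists); otherwise the wrapper hands back x itself so
    # callers always receive a usable residual_out tensor.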
+ if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) if residual is not None: residual_dtype = residual.dtype + if residual_out is None and ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + residual_out = torch.empty_like( + x, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) + else: + residual_out = None + y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( + x, + weight, + bias, + eps, + out, + residual=residual, + x1=x1, + weight1=weight1, + bias1=bias1, + dropout_p=dropout_p, + rowscale=rowscale, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + residual_out=residual_out, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if residual_out is None: + residual_out = x + return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 + + +# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema +# since we're returning a tuple of tensors +@triton_op("flash_attn::layer_norm_fwd_impl", mutates_args={"out", "residual_out"}, + schema="(Tensor x, Tensor weight, Tensor bias, float eps, Tensor(a!) out, Tensor? residual, Tensor? x1, Tensor? weight1, Tensor? bias1, float dropout_p, Tensor? rowscale, bool zero_centered_weight, bool is_rms_norm, bool return_dropout_mask, Tensor(a!)? residual_out) -> (Tensor y1, Tensor mean, Tensor rstd, Tensor seeds, Tensor dropout_mask, Tensor dropout_mask1)") +def _layer_norm_fwd_impl( + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + out: Tensor, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + residual_out: Optional[Tensor] = None +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): M, N = x.shape assert x.stride(-1) == 1 if residual is not None: @@ -303,33 +391,16 @@ def _layer_norm_fwd( if rowscale is not None: assert rowscale.is_contiguous() assert rowscale.shape == (M,) - # allocate output - if out is None: - out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) - else: - assert out.shape == x.shape + assert out.shape == x.shape assert out.stride(-1) == 1 + if residual_out is not None: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 if weight1 is not None: y1 = torch.empty_like(out) assert y1.stride(-1) == 1 else: y1 = None - if ( - residual is not None - or (residual_dtype is not None and residual_dtype != x.dtype) - or dropout_p > 0.0 - or rowscale is not None - or x1 is not None - ): - if residual_out is None: - residual_out = torch.empty( - M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype - ) - else: - assert residual_out.shape == x.shape - assert residual_out.stride(-1) == 1 - else: - residual_out = None mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None rstd = torch.empty((M,), dtype=torch.float32, device=x.device) if dropout_p > 0.0: @@ -339,16 +410,20 @@ def _layer_norm_fwd( else: seeds = None if return_dropout_mask and dropout_p > 0.0: - dropout_mask = torch.empty(M if x1 is None else 2 
* M, N, device=x.device, dtype=torch.bool) + dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) + if x1 is not None: + dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask1 = None else: - dropout_mask = None + dropout_mask, dropout_mask1 = None, None # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") with torch.cuda.device(x.device.index): - _layer_norm_fwd_1pass_kernel[(M,)]( + torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( x, out, weight, @@ -362,6 +437,7 @@ def _layer_norm_fwd( rowscale, seeds, dropout_mask, + dropout_mask1, mean, rstd, x.stride(0), @@ -374,6 +450,8 @@ def _layer_norm_fwd( N, eps, dropout_p, + # Passing bool make torch inductor very unhappy since it then tries to compare to int_max + int(zero_centered_weight), is_rms_norm, BLOCK_N, residual is not None, @@ -382,22 +460,11 @@ def _layer_norm_fwd( dropout_p > 0.0, dropout_mask is not None, rowscale is not None, + HAS_X1=x1 is not None, + HAS_W1=weight1 is not None, + HAS_B1=bias1 is not None, ) - # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 - if dropout_mask is not None and x1 is not None: - dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) - else: - dropout_mask1 = None - return ( - out, - y1, - mean, - rstd, - residual_out if residual_out is not None else x, - seeds, - dropout_mask, - dropout_mask1, - ) + return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 @triton.autotune( @@ -407,11 +474,11 @@ def _layer_norm_fwd( # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) -@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) -@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) -@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) -@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) -@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +# @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) +# @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) +# @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) +# @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) +# @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) @triton.jit def _layer_norm_bwd_kernel( X, # pointer to the input @@ -445,6 +512,7 @@ def _layer_norm_bwd_kernel( N, # number of columns in X eps, # epsilon to avoid division by zero dropout_p, + zero_centered_weight, rows_per_program, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, @@ -478,10 +546,14 @@ def _layer_norm_bwd_kernel( if RECOMPUTE_OUTPUT: Y += row_start * stride_y_row w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 if RECOMPUTE_OUTPUT and HAS_BIAS: b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) if HAS_DY1: w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 dw = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_BIAS: db = tl.zeros((BLOCK_N,), dtype=tl.float32) @@ -567,31 +639,93 @@ def 
_layer_norm_bwd_kernel( def _layer_norm_bwd( - dy, - x, - weight, - bias, - eps, - mean, - rstd, - dresidual=None, - dy1=None, - weight1=None, - bias1=None, - seeds=None, - dropout_p=0.0, - rowscale=None, - has_residual=False, - has_x1=False, - is_rms_norm=False, - x_dtype=None, - recompute_output=False, -): + dy: Tensor, + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + mean: Tensor, + rstd: Tensor, + dresidual: Optional[Tensor] = None, + dy1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + seeds: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + has_residual: bool = False, + has_x1: bool = False, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + x_dtype: Optional[torch.dtype] = None, + recompute_output: bool = False, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + # Need to wrap to handle the case where dresidual_in or dx1 are aliases of x, + # which makes torch.library unhappy + dx, dw, db, dresidual_in, dx1, dw1, db1, y = _layer_norm_bwd_impl( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual, + dy1, + weight1, + bias1, + seeds, + dropout_p, + rowscale, + has_residual, + has_x1, + zero_centered_weight, + is_rms_norm, + x_dtype=x_dtype, + recompute_output=recompute_output, + ) + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: + dresidual_in = dx + if has_x1 and dropout_p == 0.0: + dx1 = dx + return dx, dw, db, dresidual_in, dx1, dw1, db1, y + + + +@triton_op("flash_attn::layer_norm_bwd_impl", mutates_args={}, + schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? 
x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)", + allow_decomposition=False, # Don't let torch.compile trace inside + ) +def _layer_norm_bwd_impl( + dy: Tensor, + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + mean: Tensor, + rstd: Tensor, + dresidual: Optional[Tensor] = None, + dy1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + seeds: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + has_residual: bool = False, + has_x1: bool = False, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + x_dtype: Optional[torch.dtype] = None, + recompute_output: bool = False, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): M, N = x.shape assert x.stride(-1) == 1 + dy = maybe_contiguous_lastdim(dy) assert dy.stride(-1) == 1 assert dy.shape == (M, N) if dresidual is not None: + dresidual = maybe_contiguous_lastdim(dresidual) assert dresidual.stride(-1) == 1 assert dresidual.shape == (M, N) assert weight.shape == (N,) @@ -600,6 +734,7 @@ def _layer_norm_bwd( assert bias.stride(-1) == 1 assert bias.shape == (N,) if dy1 is not None: + dy1 = maybe_contiguous_lastdim(dy1) assert weight1 is not None assert dy1.shape == dy.shape assert dy1.stride(-1) == 1 @@ -651,7 +786,7 @@ def _layer_norm_bwd( rows_per_program = math.ceil(M / sm_count) grid = (sm_count,) with torch.cuda.device(x.device.index): - _layer_norm_bwd_kernel[grid]( + torch.library.wrap_triton(_layer_norm_bwd_kernel)[grid]( x, weight, bias, @@ -683,6 +818,8 @@ def _layer_norm_bwd( N, eps, dropout_p, + # Passing bool make torch inductor very unhappy since it then tries to compare to int_max + int(zero_centered_weight), rows_per_program, is_rms_norm, BLOCK_N, @@ -690,21 +827,18 @@ def _layer_norm_bwd( dresidual_in is not None, bias is not None, dropout_p > 0.0, + HAS_ROWSCALE=rowscale is not None, + HAS_DY1=dy1 is not None, + HAS_DX1=dx1 is not None, + HAS_B1=bias1 is not None, + RECOMPUTE_OUTPUT=y is not None, ) dw = _dw.sum(0).to(weight.dtype) db = _db.sum(0).to(bias.dtype) if bias is not None else None dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None - # Don't need to compute dresidual_in separately in this case - if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: - dresidual_in = dx - if has_x1 and dropout_p == 0.0: - dx1 = dx - return ( - (dx, dw, db, dresidual_in, dx1, dw1, db1) - if not recompute_output - else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) - ) + # dresidual_in and dx1 could be None, the wrapper will handle assigning them from dx + return dx, dw, db, dresidual_in, dx1, dw1, db1, y class LayerNormFn(torch.autograd.Function): @@ -723,34 +857,27 @@ def forward( rowscale=None, prenorm=False, residual_in_fp32=False, + zero_centered_weight=False, is_rms_norm=False, return_dropout_mask=False, + out_dtype=None, out=None, residual_out=None ): x_shape_og = x.shape # reshape input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() + x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) if residual is not None: assert residual.shape == x_shape_og - residual = residual.reshape(-1, residual.shape[-1]) - if residual.stride(-1) != 1: - residual = residual.contiguous() + residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) if x1 is 
not None: assert x1.shape == x_shape_og assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - x1 = x1.reshape(-1, x1.shape[-1]) - if x1.stride(-1) != 1: - x1 = x1.contiguous() + x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1])) weight = weight.contiguous() - if bias is not None: - bias = bias.contiguous() - if weight1 is not None: - weight1 = weight1.contiguous() - if bias1 is not None: - bias1 = bias1.contiguous() + bias = maybe_contiguous(bias) + weight1 = maybe_contiguous(weight1) + bias1 = maybe_contiguous(bias1) if rowscale is not None: rowscale = rowscale.reshape(-1).contiguous() residual_dtype = ( @@ -773,11 +900,13 @@ def forward( bias1, dropout_p=dropout_p, rowscale=rowscale, + out_dtype=out_dtype, residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, is_rms_norm=is_rms_norm, return_dropout_mask=return_dropout_mask, out=out, - residual_out=residual_out + residual_out=residual_out, ) ctx.save_for_backward( residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd @@ -790,6 +919,7 @@ def forward( ctx.has_x1 = x1 is not None ctx.prenorm = prenorm ctx.x_dtype = x.dtype + ctx.zero_centered_weight = zero_centered_weight y = y.reshape(x_shape_og) y1 = y1.reshape(x_shape_og) if y1 is not None else None residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None @@ -818,26 +948,19 @@ def forward( def backward(ctx, dy, *args): x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors dy = dy.reshape(-1, dy.shape[-1]) - if dy.stride(-1) != 1: - dy = dy.contiguous() - assert dy.shape == x.shape if weight1 is not None: dy1, args = args[0], args[1:] dy1 = dy1.reshape(-1, dy1.shape[-1]) - if dy1.stride(-1) != 1: - dy1 = dy1.contiguous() assert dy1.shape == x.shape else: dy1 = None if ctx.prenorm: dresidual = args[0] dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - if dresidual.stride(-1) != 1: - dresidual = dresidual.contiguous() assert dresidual.shape == x.shape else: dresidual = None - dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( + dx, dw, db, dresidual_in, dx1, dw1, db1, _ = _layer_norm_bwd( dy, x, weight, @@ -854,8 +977,10 @@ def backward(ctx, dy, *args): rowscale, ctx.has_residual, ctx.has_x1, + ctx.zero_centered_weight, ctx.is_rms_norm, x_dtype=ctx.x_dtype, + recompute_output=False, ) return ( dx.reshape(ctx.x_shape_og), @@ -874,6 +999,8 @@ def backward(ctx, dy, *args): None, None, None, + None, + None, ) @@ -890,8 +1017,10 @@ def layer_norm_fn( rowscale=None, prenorm=False, residual_in_fp32=False, + zero_centered_weight=False, is_rms_norm=False, return_dropout_mask=False, + out_dtype=None, out=None, residual_out=None ): @@ -908,8 +1037,10 @@ def layer_norm_fn( rowscale, prenorm, residual_in_fp32, + zero_centered_weight, is_rms_norm, return_dropout_mask, + out_dtype, out, residual_out ) @@ -928,7 +1059,9 @@ def rms_norm_fn( rowscale=None, prenorm=False, residual_in_fp32=False, + zero_centered_weight=False, return_dropout_mask=False, + out_dtype=None, out=None, residual_out=None ): @@ -945,8 +1078,10 @@ def rms_norm_fn( rowscale, prenorm, residual_in_fp32, + zero_centered_weight, True, return_dropout_mask, + out_dtype, out, residual_out ) @@ -954,7 +1089,8 @@ def rms_norm_fn( class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None): + def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, + device=None, dtype=None): factory_kwargs = {"device": device, "dtype": 
dtype} super().__init__() self.eps = eps @@ -962,12 +1098,16 @@ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None self.drop = torch.nn.Dropout(dropout_p) else: self.drop = None + self.zero_centered_weight = zero_centered_weight self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self): - torch.nn.init.ones_(self.weight) + if not self.zero_centered_weight: + torch.nn.init.ones_(self.weight) + else: + torch.nn.init.zeros_(self.weight) def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): return rms_norm_fn( @@ -979,6 +1119,7 @@ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, prenorm=prenorm, residual_in_fp32=residual_in_fp32, + zero_centered_weight=self.zero_centered_weight, ) @@ -1000,17 +1141,12 @@ def forward( ): x_shape_og = x.shape # reshape input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() + x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) if residual is not None: assert residual.shape == x_shape_og - residual = residual.reshape(-1, residual.shape[-1]) - if residual.stride(-1) != 1: - residual = residual.contiguous() + residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) norm_weight = norm_weight.contiguous() - if norm_bias is not None: - norm_bias = norm_bias.contiguous() + norm_bias = maybe_contiguous(norm_bias) residual_dtype = ( residual.dtype if residual is not None @@ -1049,14 +1185,11 @@ def backward(ctx, dout, *args): dout = dout.reshape(-1, dout.shape[-1]) dy = F.linear(dout, linear_weight.t()) dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) - if dy.stride(-1) != 1: - dy = dy.contiguous() + dy = maybe_contiguous_lastdim(dy) assert dy.shape == x.shape if ctx.prenorm: dresidual = args[0] - dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - if dresidual.stride(-1) != 1: - dresidual = dresidual.contiguous() + dresidual = maybe_contiguous_lastdim(dresidual.reshape(-1, dresidual.shape[-1])) assert dresidual.shape == x.shape else: dresidual = None diff --git a/flash_attn/ops/triton/rotary.py b/flash_attn/ops/triton/rotary.py index 560c75d002d..55e07eff9d0 100644 --- a/flash_attn/ops/triton/rotary.py +++ b/flash_attn/ops/triton/rotary.py @@ -18,7 +18,7 @@ def rotary_kernel( SEQLEN_OFFSETS, # this could be int or a pointer # Matrix dimensions seqlen, - rotary_dim, + nheads, seqlen_ro, # strides stride_out_batch, @@ -30,104 +30,72 @@ def rotary_kernel( stride_x_nheads, stride_x_headdim, # Meta-parameters - BLOCK_K: tl.constexpr, + # We want ROTARY_DIM to be constexpr, otherwise the triton compiler doesn't know that + # the mask is constant every 8 elements, and it will generate LDG.16 instead of LDG.128 + ROTARY_DIM: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, + BLOCK_H: tl.constexpr, BLOCK_M: tl.constexpr, ): - pid_m = tl.program_id(axis=0) - pid_head = tl.program_id(axis=1) + BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM) + ROTARY_DIM_HALF = ROTARY_DIM // 2 + pid_head = tl.program_id(axis=0) + pid_m = tl.program_id(axis=1) pid_batch = tl.program_id(axis=2) - rotary_dim_half = rotary_dim // 2 if not IS_VARLEN: - X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads - OUT = OUT + pid_batch * stride_out_batch + 
pid_head * stride_out_nheads + X = X + pid_batch * stride_x_batch + OUT = OUT + pid_batch * stride_out_batch else: start_idx = tl.load(CU_SEQLENS + pid_batch) seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads - OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + X = X + start_idx * stride_x_seqlen + OUT = OUT + start_idx * stride_out_seqlen if pid_m * BLOCK_M >= seqlen: return + + rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) if not IS_SEQLEN_OFFSETS_TENSOR: rm_cs = rm + SEQLEN_OFFSETS else: rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) - rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF) + cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32) + sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin if not INTERLEAVED: # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT - X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) - COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) - SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) - cos = tl.load( - COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0 - ).to(tl.float32) - sin = tl.load( - SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0 - ).to(tl.float32) - x0 = tl.load( - X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0 - ).to(tl.float32) - x1 = tl.load( - X + rotary_dim_half * stride_x_headdim, - mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), - other=0.0, - ).to(tl.float32) - if CONJUGATE: - sin = -sin + X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk_half[None, None, :] * stride_x_headdim) + OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk_half[None, None, :] * stride_out_headdim) + mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk_half[None, None, :] < ROTARY_DIM_HALF) + x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0,).to(tl.float32) o0 = x0 * cos - x1 * sin o1 = x0 * sin + x1 * cos - # write back result - OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) - tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) - tl.store( - OUT + rotary_dim_half * stride_out_headdim, - o1, - mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), - ) + tl.store(OUT, o0, mask=mask) + tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask) else: - # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow. - # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...]. - # Loading x0 will be fast but x1 will be slow. - # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...]. - # Then we do the calculation and use tl.where to pick put the right outputs for the even - # and for the odd indices. - rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... 
- rk_repeat = tl.arange(0, BLOCK_K) // 2 - X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) - X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim) - COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) - SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) - cos = tl.load( - COS, - mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), - other=1.0, - ).to(tl.float32) - sin = tl.load( - SIN, - mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), - other=0.0, - ).to(tl.float32) - x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to( - tl.float32 - ) - x1 = tl.load( - X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 - ).to(tl.float32) - if CONJUGATE: - sin = -sin - x0_cos = x0 * cos - x1_sin = x1 * sin - out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) - OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) - tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + rk = tl.arange(0, BLOCK_K) + X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk[None, None, :] * stride_x_headdim) + OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk[None, None, :] * stride_out_headdim) + mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk[None, None, :] < ROTARY_DIM) + x = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2])) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K]) + tl.store(OUT, o, mask=mask) def apply_rotary( @@ -169,13 +137,6 @@ def apply_rotary( assert headdim <= 256, "Only support headdim <= 256" assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" - assert ( - cos.dtype == sin.dtype - ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" - assert ( - x.dtype == cos.dtype - ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" - cos, sin = cos.contiguous(), sin.contiguous() if isinstance(seqlen_offsets, torch.Tensor): assert seqlen_offsets.shape == (batch,) @@ -188,18 +149,13 @@ def apply_rotary( if rotary_dim < headdim and not inplace: output[..., rotary_dim:].copy_(x[..., rotary_dim:]) - BLOCK_K = ( - 32 - if rotary_dim <= 32 - else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) - ) - grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), nheads, batch) # noqa - BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 128 else 4) + grid = lambda META: (triton.cdiv(nheads, META["BLOCK_H"]), triton.cdiv(seqlen, META["BLOCK_M"]), batch) # noqa + BLOCK_M = 8 if rotary_dim <= 128 else 4 # Need this, otherwise Triton tries to launch from cuda:0 and we get # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 
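    # Descriptive note on the launch changes in this patch: the grid is now
    # (ceil(nheads / BLOCK_H), ceil(seqlen / BLOCK_M), batch), so each program
    # rotates a BLOCK_H x BLOCK_M tile of (head, position) pairs, and rotary_dim is
    # passed as the constexpr ROTARY_DIM so the compiler can keep the masks regular
    # and emit wide, vectorized loads.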
with torch.cuda.device(x.device.index): - rotary_kernel[grid]( + torch.library.wrap_triton(rotary_kernel)[grid]( output, # data ptrs x, cos, @@ -207,7 +163,7 @@ def apply_rotary( cu_seqlens, seqlen_offsets, seqlen, # shapes - rotary_dim, + nheads, seqlen_ro, output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 output.stride(-3), # seqlen_stride or total_seqlen_stride @@ -217,12 +173,12 @@ def apply_rotary( x.stride(-3), # seqlen stride or total_seqlen_stride x.stride(-2), # nheads stride x.stride(-1), # headdim stride - BLOCK_K, + rotary_dim, isinstance(seqlen_offsets, torch.Tensor), is_varlen, interleaved, conjugate, - BLOCK_M, - num_warps=2 if rotary_dim <= 64 else 4, + BLOCK_M=BLOCK_M, + BLOCK_H=2, ) return output diff --git a/flash_attn/pyproject.toml b/flash_attn/pyproject.toml index 3201555763e..ce5eac916cd 100644 --- a/flash_attn/pyproject.toml +++ b/flash_attn/pyproject.toml @@ -1,3 +1,6 @@ [tool.black] line-length = 100 -target-version = ['py38'] \ No newline at end of file +target-version = 'py39' +[tool.ruff] +line-length = 100 +target-version = 'py39' \ No newline at end of file diff --git a/flash_attn/utils/library.py b/flash_attn/utils/library.py new file mode 100644 index 00000000000..05324bb01a4 --- /dev/null +++ b/flash_attn/utils/library.py @@ -0,0 +1,66 @@ +# Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/_library/triton.py +# The PyTorch implementation simply ignores the schema argument, we simply modify it to use schema. + +from typing import Optional, Callable, Iterable, Union + +from torch.library import custom_op, CustomOpDef +from torch._library.triton import set_wrap_triton_enabled + + +def triton_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + schema: Optional[str] = None, + # If allow_decomposition=True, this matches torch.library.triton_op behavior. If set to False, + # then it behaves like torch.library.custom_op instead, which doesn't decompose the operator + # and so inductor can't trace inside. + allow_decomposition=True, +) -> Callable: + def dec(fn: Callable[..., object]) -> CustomOpDef: + def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def] + # Optimization: we're passing regular Tensors into the triton kernel, so + # no need to go through HOP dispatch + with set_wrap_triton_enabled(False): + return fn(*args, **kwargs) + + result = custom_op( + name, + backend_fn, + mutates_args=mutates_args, + # This is the only difference with the PyTorch implementation + schema=schema, + ) + from torch._subclasses.functional_tensor import FunctionalTensorMode + + # We require that the user pass us a function that is make_fx traceable, + # so we can just register it as the Fake/meta kernel. + result.register_fake(fn) + + if allow_decomposition: + # We decompose the operator when FunctionalTensorMode is active. + # The goal is to decompose the operator in AOTDispatcher. + # - With torch.compile, this means that the backend (usually Inductor) + # can see a call to the triton kernel(s) and so it can directly optimize + # them by inlining them into the lowering process. 
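To make the intent of allow_decomposition concrete, here is a hedged usage sketch of this wrapper; the op name, schema, and kernel launch below are hypothetical, not registrations that exist in the library.

```python
# Hypothetical example of decorating a Triton-backed function with the triton_op
# wrapper defined in this file (names here are made up for illustration).
import torch
from flash_attn.utils.library import triton_op

@triton_op(
    "mylib::apply_rotary",
    mutates_args=(),
    schema="(Tensor x, Tensor cos, Tensor sin) -> Tensor",
    allow_decomposition=True,  # decompose in AOTDispatcher so Inductor can see the kernel
)
def apply_rotary_op(x, cos, sin):
    out = torch.empty_like(x)
    # ... launch the Triton kernel here, e.g. via torch.library.wrap_triton(kernel)[grid](...)
    return out

# With allow_decomposition=False the op behaves like torch.library.custom_op:
# torch.compile treats it as opaque and will not trace into the kernel launch.
```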
+ def functional_decomp( # type: ignore[no-untyped-def] + mode, op, types, args, kwargs + ): + from torch.export._trace import custom_triton_ops_decomposition_disabled + + if custom_triton_ops_decomposition_disabled(): + return mode.__torch_dispatch__(op, types, args, kwargs) + else: + with mode: + return fn(*args, **kwargs) + + result.register_torch_dispatch(FunctionalTensorMode, functional_decomp) + + return result + + if fn is None: + return dec + else: + return dec(fn) diff --git a/flash_attn/utils/testing.py b/flash_attn/utils/testing.py new file mode 100644 index 00000000000..81be51f1de8 --- /dev/null +++ b/flash_attn/utils/testing.py @@ -0,0 +1,360 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +import math +from typing import Optional + +import torch +from einops import rearrange, repeat + +from flash_attn.bert_padding import pad_input, unpad_input + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint( + max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device + ) + elif mode == "third": + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + + if zero_lengths: + # Generate zero-lengths every 5 batches and the last batch. + for i in range(batch_size): + if i % 5 == 0: + lengths[i] = 0 + lengths[-1] = 0 + padding_mask = ( + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + ) + return padding_mask + + +def generate_qkv( + q, k, v, query_padding_mask=None, key_padding_mask=None, qv=None, kvpacked=False, qkvpacked=False, + query_unused_mask=None, key_unused_mask=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d_v) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + d_v = v.shape[-1] + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d_v) + if query_unused_mask is not None or key_unused_mask is not None: + assert not kvpacked + assert not qkvpacked + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q, query_padding_mask, query_unused_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + ) + seqused_q = None + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + qv_unpad = rearrange(qv, "b s ... 
-> (b s) ...") if qv is not None else None + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( + k, key_padding_mask, key_unused_mask + ) + v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + ) + seqused_k = None + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + else: + dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + qv_unpad.detach() if qv is not None else None, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + qv.detach() if qv is not None else None, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(None, None), + sink_token_length=0, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] is None: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + 
torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length), + ) + + +def construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + # Subtract remainder instead of divide and then multiply to take care of negative values + col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk + return torch.logical_or( + col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + qv=None, + q_descale=None, k_descale=None, v_descale=None, + window_size=(None, None), + attention_chunk=0, + sink_token_length=0, + learnable_sink: Optional[torch.Tensor] = None, + softcap=0.0, + upcast=True, + reorder_ops=False, + intermediate_dtype=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads, head_dim) + v: (batch_size, seqlen_k, nheads, head_dim_v) + qv: (batch_size, seqlen_q, nheads, head_dim_v) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. 
+ Output: + output: (batch_size, seqlen_q, nheads, head_dim_v) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None + if q_descale is not None: + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) + q = (q.float() * q_descale).to(q.dtype) + qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None + if k_descale is not None: + k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) + if v_descale is not None: + v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype) + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) + if softcap > 0: + scores = torch.tanh(scores / softcap) * softcap + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + local_mask = None + if window_size[0] is not None or window_size[1] is not None: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + sink_token_length, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + if attention_chunk > 0: + chunk_mask = construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask + if local_mask is not None: + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + if learnable_sink is None: + attention = torch.softmax(scores, dim=-1).to(v.dtype) + else: + scores_fp32 = scores.to(torch.float32) + logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) + learnable_sink = rearrange(learnable_sink, "h -> h 1 1") + logits_or_sinks_max = torch.maximum(learnable_sink, logits_max) + unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(learnable_sink - logits_or_sinks_max) + attention = (unnormalized_scores / normalizer).to(v.dtype) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + # Without this we might get NaN in dv + if key_padding_mask is not None: + attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if local_mask is not None: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if 
dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + if intermediate_dtype is not None: + attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 33e5d282716..e94d325d42d 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -68,7 +68,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, w else: row_idx = torch.arange(seqlen_q, device='cuda') col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) - col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) + col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)) avg_seqlen = (col_right - col_left + 1).float().mean().item() return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) diff --git a/hopper/block.h b/hopper/block.h index 03ac38476ae..b07b414ca45 100644 --- a/hopper/block.h +++ b/hopper/block.h @@ -15,6 +15,7 @@ struct BlockMN { SeqlenInfo_t const& seqlen_info, int const m_block, int const bidb, int const split_idx, int const num_splits, int const window_size_left, int const window_size_right, + cutlass::FastDivmod const& attention_chunk_divmod, cutlass::FastDivmod const& qhead_per_khead_divmod) { int seqlen_k = seqlen_info.seqlen_k; @@ -25,8 +26,13 @@ struct BlockMN { if constexpr (Is_local) { int m_idx_min = m_block * kBlockM; if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); } + int const n_idx = m_idx_min + seqlen_k - seqlen_q; // unlike previously, we don't divide by kBlockN because we want offset for seqlen_k - n_offset = std::max(int(0), m_idx_min + seqlen_k - seqlen_q - window_size_left); + int n_idx_left = n_idx - window_size_left; + if (attention_chunk_divmod.divisor > 0) { + n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx)); + } + n_offset = std::max(int(0), n_idx_left); // Subtract n_offset from seqlen_k for subsequent calculations such as n_block_max // This is the actual seqlen_k processed for this m_block seqlen_k -= n_offset; @@ -42,6 +48,10 @@ struct BlockMN { // cp_world_size is guaranteed to be greater than 0 int tot_seqlen_k = (Is_local) ? 
seqlen_k : seqlen_info.tot_seqlen_k; int n_token_max = m_idx_max + tot_seqlen_k - seqlen_q + window_size_right; + if (Is_local && attention_chunk_divmod.divisor > 0) { + int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q; + n_token_max = std::min(n_token_max, flash::round_up(attention_chunk_divmod, n_idx)); + } if (seqlen_info.cp_world_size > 1 && !Is_local) { n_token_max = cute::ceil_div(n_token_max - seqlen_info.cp_rank, seqlen_info.cp_world_size); } @@ -73,11 +83,12 @@ struct BlockMN { SeqlenInfo_t const& seqlen_info, int const m_block, int const bidb, int const split_idx, int const num_splits, int const window_size_left, int const window_size_right, + cutlass::FastDivmod const& attention_chunk_divmod, cutlass::FastDivmod const& qhead_per_khead_divmod) { // TODO: check logic with n_offset auto [n_block_min, n_block_max, n_offset] = get_n_block_min_max( seqlen_info, m_block, bidb, split_idx, num_splits, - window_size_left, window_size_right, qhead_per_khead_divmod); + window_size_left, window_size_right, attention_chunk_divmod, qhead_per_khead_divmod); int const idx_k_new_min = std::max(n_block_min * kBlockN + n_offset - seqlen_info.seqlen_k_og, 0); int const idx_k_new_max = std::min(n_block_max * kBlockN + n_offset - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new); int const n_block_new_min = idx_k_new_min / kBlockN; @@ -108,6 +119,40 @@ struct BlockMN { return {m_block_min, m_block_max}; } + // If we have separate iterations with causal or local masking at the start, where do we stop + static + CUTLASS_DEVICE + int get_n_block_min_causal_local_mask( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const n_block_min, int const window_size_right, + cutlass::FastDivmod const& attention_chunk_divmod, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + int const m_idx_min = !PackGQA ? m_block * kBlockM : qhead_per_khead_divmod.divide(m_block * kBlockM); + int const n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q; + int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right; + if (Is_local && attention_chunk_divmod.divisor > 0) { + n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx)); + } + return std::max(n_block_min, n_idx_right / kBlockN); + } + + // If we have separate iterations with local masking at the end, where do we stop the non-masked iterations + static + CUTLASS_DEVICE + int get_n_block_min_before_local_mask( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const n_block_min, int const window_size_left, + cutlass::FastDivmod const& attention_chunk_divmod, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; + int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q; + int n_idx_left = !Is_local ? n_idx : n_idx - window_size_left; + if (Is_local && attention_chunk_divmod.divisor > 0) { + n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx)); + } + return !Is_local ? 
n_block_min : std::max(n_block_min, cute::ceil_div(n_idx_left, kBlockN)); + } + }; } // namespace flash diff --git a/hopper/epilogue_bwd.hpp b/hopper/epilogue_bwd.hpp index 9362b040453..fdae7616683 100644 --- a/hopper/epilogue_bwd.hpp +++ b/hopper/epilogue_bwd.hpp @@ -107,7 +107,9 @@ struct CollectiveEpilogueBwd { ShapedKV const shape_dK; StridedKV const stride_dK; Element* ptr_dV; + ShapedKV const shape_dV; StridedKV const stride_dV; + int const num_batch; int const num_heads_q; int* dk_semaphore; int* dv_semaphore; @@ -121,6 +123,7 @@ struct CollectiveEpilogueBwd { ShapedKV const shape_dK; StridedKV const stride_dK; Element* ptr_dV; + ShapedKV const shape_dV; StridedKV const stride_dV; TMA_dKV tma_store_dK, tma_store_dV; int const* cu_seqlens = nullptr; @@ -130,7 +133,7 @@ struct CollectiveEpilogueBwd { static Params to_underlying_arguments(Arguments const& args) { Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK); - Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV); + Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV); TMA_dKV tma_store_dK = [&] { if constexpr (Use_TMA) { return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV @@ -145,7 +148,7 @@ struct CollectiveEpilogueBwd { return nullptr; } }(); - return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV, + return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.shape_dV, args.stride_dV, tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused}; } @@ -197,7 +200,7 @@ struct CollectiveEpilogueBwd { cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK); - Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK); + Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dV); Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) auto block_tma_dK = params.tma_store_dK.get_slice(_0{}); @@ -227,7 +230,7 @@ struct CollectiveEpilogueBwd { bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? 
bidb : 0); Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) GmemTiledCopydKV gmem_tiled_copy_dKV; @@ -241,25 +244,28 @@ struct CollectiveEpilogueBwd { Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdV))); + Tensor tdKVpdV = make_tensor(make_shape(size<2>(tdKVgdV))); + Tensor tdKVpdK = make_tensor(make_shape(size<2>(tdKVgdK))); #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } + for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); } + #pragma unroll + for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } // Need to check OOB when reading from smem if kBlockN isn't evenly tiled static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0; flash::copy( - gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdKV, kBlockN); + gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdV, kBlockN); flash::copy( - gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdKV, kBlockN); + gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdK, kBlockN); // // Tell warp 0 that smem_k and smem_v are ready // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); // Construct identity layout for gdKV // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) + gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) ); flash::copy( - gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) + gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdK, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) ); } } @@ -282,7 +288,7 @@ struct CollectiveEpilogueBwd { bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? 
bidb : 0); Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) GmemTiledCopydKV gmem_tiled_copy_dKV; @@ -295,15 +301,18 @@ struct CollectiveEpilogueBwd { Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdK))); + Tensor tdKVpdK = make_tensor(make_shape(size<2>(tdKVgdK))); + Tensor tdKVpdV = make_tensor(make_shape(size<2>(tdKVgdV))); + #pragma unroll + for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } + for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN + gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdK, seqlen_info.seqlen - n_block * kBlockN ); flash::copy( - gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN + gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdV, seqlen_info.seqlen - n_block * kBlockN ); } @@ -359,8 +368,10 @@ struct CollectiveEpilogueBwdGQA { ShapedKV const shape_dKaccum; StridedKV const stride_dKaccum; ElementAccum* ptr_dVaccum; + ShapedKV const shape_dVaccum; StridedKV const stride_dVaccum; - int num_heads_q; + int const num_batch; + int const num_heads_q; int* dk_semaphore; int* dv_semaphore; int const* cu_seqlens; @@ -373,10 +384,12 @@ struct CollectiveEpilogueBwdGQA { ShapedKV const shape_dKaccum; StridedKV const stride_dKaccum; ElementAccum* ptr_dVaccum; + ShapedKV const shape_dVaccum; StridedKV const stride_dVaccum; cutlass::FastDivmod qhead_per_khead_divmod; int* dk_semaphore; int* dv_semaphore; + int const num_batch; int const* cu_seqlens = nullptr; int const* seqused = nullptr; }; @@ -387,10 +400,10 @@ struct CollectiveEpilogueBwdGQA { assert(args.dk_semaphore != nullptr); assert(args.dv_semaphore != nullptr); } - return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.stride_dVaccum, + return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum, cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))), args.dk_semaphore, args.dv_semaphore, - args.cu_seqlens, args.seqused}; + args.num_batch, args.cu_seqlens, args.seqused}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -419,7 +432,7 @@ struct CollectiveEpilogueBwdGQA { flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0); - Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dKaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? 
bidb : 0); + Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0); Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape>{}, make_coord(n_block)); // (M * K) Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) @@ -439,8 +452,8 @@ struct CollectiveEpilogueBwdGQA { cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum); } - // int const num_batch = params.num_batch; - int const num_batch = get<2>(params.shape_dKaccum); + int const num_batch = params.num_batch; + // int const num_batch = get<2>(params.shape_dKaccum); // erroneously returns 1 for varlen int const num_head_kv = get<1>(params.shape_dKaccum); int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv; using Barrier = cutlass::GenericBarrier; diff --git a/hopper/flash.h b/hopper/flash.h index e5c6d6238f6..c3cf4a742e9 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -135,6 +135,7 @@ struct Flash_fwd_params : public Qkv_params { // Local window size int window_size_left, window_size_right; + int attention_chunk; // Pointer to the RNG seed (idx 0) and offset (idx 1). uint64_t * rng_state; @@ -151,8 +152,7 @@ struct Flash_fwd_params : public Qkv_params { bool pack_gqa; int * __restrict__ tile_count_semaphore; - // int * __restrict__ num_m_blocks_ptr; - int * __restrict__ prepare_seqlen_q_ptr; + int * __restrict__ num_m_blocks_ptr; // int * __restrict__ num_n_blocks_ptr; int * __restrict__ num_splits_dynamic_ptr; int * __restrict__ varlen_batch_idx_ptr; // virtual -> actual diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 9e0baffcd2a..e5624fa7f31 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -2,10 +2,9 @@ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ -// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. -#include -#include // For TORCH_VERSION* macros -#include +#include +#include +#include #include #include @@ -16,44 +15,25 @@ #include "heuristics.h" #include "cuda_check.h" -// Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909 -// This is so that we can pass in torch.dtype as a parameter to the function. -#if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4) - -#include -#include - -namespace pybind11::detail { - - template <> - struct type_caster { - public: - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype")); - // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType - // cannot be default-initialized, we provide this constructor to explicitly - // initialize that field. The value doesn't matter as it will be overwritten - // after a successful call to load. 
- type_caster() : value(at::kFloat) {} - bool load(handle src, bool) { - PyObject* obj = src.ptr(); - if (THPDtype_Check(obj)) { - value = reinterpret_cast(obj)->scalar_type; - return true; - } - return false; - } - static handle cast( - const at::ScalarType& src, - return_value_policy /* policy */, - handle /* parent */) { - return Py_NewRef(torch::getTHPDtype(src)); - } - }; - -} // namespace pybind11::detail -#endif +extern "C" { +/* Creates a dummy empty _C module that can be imported from Python. + The import from Python will load the .so consisting of this file + in this extension, so that the TORCH_LIBRARY static initializers + below are run. */ +PyObject* PyInit__C(void) +{ + static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "_C", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + NULL, /* methods */ + }; + return PyModule_Create(&module_def); +} +} #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") @@ -92,6 +72,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, float softmax_scale, int window_size_left, int window_size_right, + int attention_chunk, const float softcap=0.f, const int sm_margin=0) { @@ -164,14 +145,19 @@ void set_params_fprop(Flash_fwd_params ¶ms, // Causal is the special case where window_size_right == 0 and window_size_left < 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. - params.is_causal = window_size_left < 0 && window_size_right == 0; - params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; // TODO: check this - if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k - 1; } - if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_q - 1; } + if (window_size_left < 0) { window_size_left = seqlen_k - 1; } + if (window_size_right < 0) { window_size_right = seqlen_q - 1; } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } params.window_size_left = window_size_left; params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; @@ -214,6 +200,7 @@ void set_params_dgrad(Flash_bwd_params ¶ms, float softmax_scale, int window_size_left, int window_size_right, + int attention_chunk, const float softcap=0.f, bool deterministic=false, int const sm_margin=0) { @@ -230,6 +217,7 @@ void set_params_dgrad(Flash_bwd_params ¶ms, softmax_scale, window_size_left, window_size_right, + attention_chunk, softcap, sm_margin); @@ -264,6 +252,118 @@ void set_params_dgrad(Flash_bwd_params ¶ms, params.deterministic = deterministic; } +template +void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { + if (!params.is_e4m3) { + if (params.is_bf16) { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if 
(params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 + if constexpr (Arch == 90) { + if (params.dv > 256) { + return run_mha_fwd_(params, stream); + } else if (params.dv > 64) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + } else { + #ifndef FLASHATTENTION_DISABLE_FP16 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 + if constexpr (Arch == 90) { + if (params.dv > 256) { + return run_mha_fwd_(params, stream); + } else if (params.dv > 64) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + #else + TORCH_CHECK(false, "This flash attention build does not support FP16."); + #endif + } + } else { + #ifndef FLASHATTENTION_DISABLE_FP8 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA_HDIMDIFF>(params, stream); + } + } + #endif + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #else + TORCH_CHECK(false, "This flash attention build does not support FP8."); + #endif + } +} + void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // HEADDIM_SWITCH(params.d, [&] { // run_mha_fwd_(params, stream); @@ -283,119 +383,7 @@ void 
run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr bool PackGQA_HDIMDIFF = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split; #endif SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] { - if (!params.is_e4m3) { - if (params.is_bf16) { - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { - #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 - if (params.dv > 256 && Arch == 90) { - return run_mha_fwd_(params, stream); - } else if (params.dv > 64 && Arch == 90) { - return run_mha_fwd_(params, stream); - } else { - return run_mha_fwd_(params, stream); - } - #else - return run_mha_fwd_(params, stream); - #endif - } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { - #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 - if (params.dv <= 128 && Arch == 90) { - return run_mha_fwd_(params, stream); - } else { - return run_mha_fwd_(params, stream); - } - #else - return run_mha_fwd_(params, stream); - #endif - } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } - #endif - } else { - #ifndef FLASHATTENTION_DISABLE_FP16 - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { - #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 - if (params.dv > 256 && Arch == 90) { - return run_mha_fwd_(params, stream); - } else if (params.dv > 64 && Arch == 90) { - return run_mha_fwd_(params, stream); - } else { - return run_mha_fwd_(params, stream); - } - #else - return run_mha_fwd_(params, stream); - #endif - } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { - #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 - if (params.dv <= 128 && Arch == 90) { - return run_mha_fwd_(params, stream); - } else { - return run_mha_fwd_(params, stream); - } - #else - return run_mha_fwd_(params, stream); - #endif - } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } - #endif - #else - TORCH_CHECK(false, "This flash attention build does not support FP16."); - #endif - } - } else { - #ifndef FLASHATTENTION_DISABLE_FP8 - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { - #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 - if (params.dv <= 128 && Arch == 90) { - return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA_HDIMDIFF>(params, stream); - } else { - return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); - } - #else - return 
run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); - #endif - } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } - #endif - #else - TORCH_CHECK(false, "This flash attention build does not support FP8."); - #endif - } + run_mha_fwd_constexpr(params, stream); }); }); }); @@ -459,7 +447,7 @@ inline bool get_pack_gqa(Flash_fwd_params const& params) { // params.page_table must already be set if (params.h == params.h_k) { return false; } // This needs to match the kernel configs - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params)); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); #endif @@ -552,30 +540,30 @@ inline int round_up_headdimv(int head_size) { // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available at::Tensor mha_fwd_get_scheduler_metadata( - int batch_size, - int max_seqlen_q, - int max_seqlen_k, - int num_heads, - int num_heads_k, - int headdim, - int headdim_v, + int64_t batch_size, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + int64_t num_heads, + int64_t num_heads_k, + int64_t headdim, + int64_t headdim_v, at::ScalarType qkv_dtype, - const at::Tensor &seqused_k, // b - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &cu_seqlens_k_new_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &leftpad_k_, // b - std::optional page_size, - int max_seqlen_k_new, // 0 means we're not appending new KV + at::Tensor seqused_k, // b + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. 
+ std::optional leftpad_k_, // b + std::optional page_size, + int64_t max_seqlen_k_new, // 0 means we're not appending new KV bool is_causal, - int window_size_left, - int window_size_right, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, bool has_softcap, - int num_splits, + int64_t num_splits, std::optional pack_gqa_, - int const sm_margin - ) { + int64_t sm_margin) { TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); @@ -608,7 +596,7 @@ mha_fwd_get_scheduler_metadata( if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; } if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; } // causal=true is the same as causal=false in this case - if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) { + if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) { is_causal = false; @@ -616,12 +604,17 @@ mha_fwd_get_scheduler_metadata( } if (is_causal) { window_size_right = 0; } - params.is_causal = window_size_left < 0 && window_size_right == 0; - params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; - if (window_size_left < 0 && window_size_right >= 0) { window_size_left = max_seqlen_k - 1; } - if (window_size_left >= 0 && window_size_right < 0) { window_size_right = max_seqlen_q - 1; } + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; + if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; } + if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } params.window_size_left = window_size_left; params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; params.softcap = has_softcap ? 1.0f : 0.0f; @@ -631,20 +624,11 @@ mha_fwd_get_scheduler_metadata( bool const use_prepare_varlen = true; params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; - // set to use in split heuristic - params.num_splits_dynamic_ptr = reinterpret_cast(1); + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; params.pack_gqa = pack_gqa_.has_value() ? 
pack_gqa_.value() : get_pack_gqa(params); - // Always enable PackGQA for Split - params.pack_gqa |= params.num_splits > 1; - // printf("Num splits (metadata) = %d.\n", params.num_splits); - #ifdef FLASHATTENTION_PACKGQA_ONLY - params.pack_gqa |= params.d == params.dv; - #endif - - bool const use_dynamic_split = params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA && params.num_splits > 1; bool is_varlen = true; @@ -657,26 +641,23 @@ mha_fwd_get_scheduler_metadata( at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - // params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template - params.varlen_sort_batches = false; + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template if (scheduler_needs_semaphore || use_prepare_varlen) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers - int num_prepare_batch_vectors = use_prepare_varlen ? 1 : 0; - if (use_dynamic_split) { num_prepare_batch_vectors += 1; } - if (params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } - if (params.head_swizzle) { num_prepare_batch_vectors += 1; } - int sort_offset = b_rounded * (use_dynamic_split ? 2 : 1); - int head_swizzle_offset = b_rounded * (num_prepare_batch_vectors - 1); + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); tile_count_semaphore = torch::empty( {int(scheduler_needs_semaphore) + tile_count_semaphore_offset}, opts.dtype(torch::kInt32)); - // ORDER: {prepare_seqlen_q, num_splits_dynamic, varlen_batch_idx, num_nheads_in_l2} - params.prepare_seqlen_q_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; - params.num_splits_dynamic_ptr = use_prepare_varlen && use_dynamic_split ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; - params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + sort_offset : nullptr; + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; if (scheduler_needs_semaphore) { @@ -688,7 +669,7 @@ mha_fwd_get_scheduler_metadata( } if (use_prepare_varlen) { - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params)); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); @@ -707,45 +688,46 @@ mha_fwd_get_scheduler_metadata( // h: num_heads // h_k: num_heads_k // d: head_size -std::vector -mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. - std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new - std::optional &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q - std::optional &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &cu_seqlens_k_new_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. - std::optional max_seqlen_q_, +std::tuple +mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. + std::optional k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, // TODO: check if we need max_seqlen_k - std::optional max_seqlen_k_, - std::optional &page_table_, // (b_k, max_num_pages_per_seq) - std::optional &kv_batch_idx_, // b. 
indices to index into the KV cache - std::optional &leftpad_k_, // b - std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional &seqlens_rotary_, // b - std::optional &q_descale_, // (b, h_k), not (b, h) - std::optional &k_descale_, // (b, h_k) - std::optional &v_descale_, // (b, h_k) - float const softmax_scale, + std::optional max_seqlen_k_, + std::optional page_table_, // (b_k, max_num_pages_per_seq) + std::optional kv_batch_idx_, // b. indices to index into the KV cache + std::optional leftpad_k_, // b + std::optional rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional seqlens_rotary_, // b + std::optional q_descale_, // (b, h_k), not (b, h) + std::optional k_descale_, // (b, h_k) + std::optional v_descale_, // (b, h_k) + std::optional softmax_scale_, bool is_causal, - int window_size_left, - int window_size_right, - float const softcap, - bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - std::optional &scheduler_metadata_, // (b + 1) - int num_splits, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + double softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional scheduler_metadata_, // (b + 1) + int64_t num_splits, std::optional pack_gqa_, - int const sm_margin, - std::optional &s_aux_, // (h) - int const cp_world_size, // context parallelism (cp) world size - int const cp_rank, // cp rank - std::optional &cp_tot_seqused_k_ // b. total seqused_k in cp world + int64_t sm_margin, + std::optional s_aux_, + int64_t cp_world_size, + int64_t cp_rank, + std::optional cp_tot_seqused_k_ ) { auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -810,6 +792,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); int const num_heads_k = k.size(-2); int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } if (!kv_batch_idx_.has_value()) { TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); } @@ -839,16 +825,13 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } // causal=true is the same as causal=false in this case - if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1) { + if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA if ((head_size <= 64 || head_size > 128) || !paged_KV) { is_causal = false; } } if (is_causal) { window_size_right = 0; } - // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_fprop will set params.is_causal=true. - // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM. 
- is_causal = window_size_left < 0 && window_size_right == 0; if (!is_varlen_q) { CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); @@ -960,6 +943,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq softmax_scale, window_size_left, window_size_right, + attention_chunk, softcap, sm_margin); params.total_q = total_q; @@ -1034,15 +1018,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; // printf("Num splits = %d.\n", params.num_splits); params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); - // Always enable PackGQA for Split - params.pack_gqa |= (params.num_splits > 1); - #ifdef FLASHATTENTION_PACKGQA_ONLY - params.pack_gqa |= params.d == params.dv; - #endif - - bool const use_dynamic_split = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA && params.num_splits > 1; - // disable split for varlen and >992 batches for now - if (use_prepare_varlen && params.b > PREPARE_VARLEN_MAX_BATCHES_1CTA) { params.num_splits = 1; } // This needs to be set after get_num_splits at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic @@ -1050,17 +1025,14 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq bool const scheduler_needs_semaphore = params.arch >= 90 ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - // params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template - params.varlen_sort_batches = false; + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template if (scheduler_needs_semaphore || use_prepare_varlen) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers - int num_prepare_batch_vectors = use_prepare_varlen ? 1 : 0; - if (use_dynamic_split) { num_prepare_batch_vectors += 1; } - if (params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } - if (params.head_swizzle) { num_prepare_batch_vectors += 1; } - int sort_offset = b_rounded * (use_dynamic_split ? 2 : 1); - int head_swizzle_offset = b_rounded * (num_prepare_batch_vectors - 1); + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset; // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size); @@ -1078,10 +1050,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (scheduler_needs_semaphore && !use_prepare_varlen) { tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing } - // ORDER: {prepare_seqlen_q, num_splits_dynamic, varlen_batch_idx, num_nheads_in_l2} - params.prepare_seqlen_q_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; - params.num_splits_dynamic_ptr = use_prepare_varlen && use_dynamic_split ? 
tile_count_semaphore.data_ptr() + b_rounded : nullptr; - params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + sort_offset : nullptr; + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() + tile_count_semaphore_offset : nullptr; @@ -1090,7 +1062,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); - TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v <= 256."); + TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256."); TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "q_v is only supported for fp16 and bf16 data type"); TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); @@ -1223,8 +1195,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.s_aux_ptr = nullptr; } - params.cp_world_size = cp_world_size; - params.cp_rank = cp_rank; + params.cp_world_size = static_cast(cp_world_size); + params.cp_rank = static_cast(cp_rank); params.cp_tot_seqused_k = cp_tot_seqused_k_.has_value() ? static_cast(cp_tot_seqused_k_.value().data_ptr()) : nullptr; TORCH_CHECK(cp_world_size > 0, "cp_world_size must be positive, required by downstream unified code path. 
Use 1 if CP is not enabled."); @@ -1286,8 +1258,53 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq return {out, softmax_lse, out_accum, softmax_lse_accum}; } +#ifdef FLASHATTENTION_DISABLE_BACKWARD +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + TORCH_CHECK(false, "Flash-Attention was built with backward disabled"); +} +#else +template +void run_mha_bwd_constexpr(Flash_bwd_params ¶ms, cudaStream_t stream) { + if (!params.is_bf16) { + #ifndef FLASHATTENTION_DISABLE_FP16 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d_rounded == 64) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d_rounded == 96) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d_rounded == 128) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d_rounded == 192) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } + #endif + #else + TORCH_CHECK(false, "This flash attention build does not support FP16."); + #endif + } else { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d_rounded == 64) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d_rounded == 96) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d_rounded == 128) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d_rounded == 192) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } + #endif + } +} + void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { - #ifndef FLASHATTENTION_DISABLE_BACKWARD // FP16_SWITCH(!params.is_bf16, [&] { // HEADDIM_SWITCH(params.d, [&] { // run_mha_bwd_(params, stream); @@ -1295,47 +1312,11 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { // }); ARCH_SWITCH(params.arch, Arch, [&] { SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] { - if (!params.is_bf16) { - #ifndef FLASHATTENTION_DISABLE_FP16 - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_bwd_(params, stream); } - #endif - #else - TORCH_CHECK(false, "This flash attention build does not support FP16."); - #endif - } else { - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_bwd_(params, stream); } - #endif - } + 
run_mha_bwd_constexpr(params, stream); }); }); - #endif } +#endif // b: batch_size @@ -1344,29 +1325,30 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { // h: num_heads // h_k: num_heads_k // d: head_size -std::vector mha_bwd( - const at::Tensor &dout, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - const at::Tensor &v, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - const at::Tensor &out, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q - std::optional &dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - std::optional &dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - std::optional &dv_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. - std::optional max_seqlen_q_, - std::optional max_seqlen_k_, - float const softmax_scale, +std::tuple mha_bwd( + at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + at::Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + at::Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + at::Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q + std::optional dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + std::optional dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + std::optional dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, + std::optional max_seqlen_k_, + std::optional softmax_scale_, bool is_causal, - int window_size_left, - int window_size_right, - float const softcap, - bool const deterministic, - int const sm_margin) { + int64_t window_size_left, + int64_t window_size_right, + double softcap, + bool deterministic, + int64_t sm_margin +) { #ifdef FLASHATTENTION_DISABLE_BACKWARD TORCH_CHECK(false, "This flash attention build does not support backward."); @@ -1421,13 +1403,19 @@ std::vector mha_bwd( int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; int const num_heads = q.size(-2); int const head_size = q.size(-1); + int const head_size_v = v.size(-1); int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value(); int const total_k = !is_varlen_k ? 
batch_size * k.size(1) : k.size(0); int const num_heads_k = k.size(-2); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8"); int const max_headdim = get_max_headdim(); - TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } @@ -1438,7 +1426,9 @@ std::vector mha_bwd( is_causal = window_size_left < 0 && window_size_right == 0; int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; - int const head_size_rounded = round_up_headdim(head_size); + int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v)); + int const head_size_v_rounded = head_size_rounded; + TORCH_CHECK(!deterministic || head_size_rounded < 256, "Deterministic backward not supported for hdim 256."); // Very important that these match the kernel configs bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) @@ -1467,20 +1457,20 @@ std::vector mha_bwd( if (!is_varlen_q) { CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v); } else { CHECK_SHAPE(q, total_q, num_heads, head_size); - CHECK_SHAPE(out, total_q, num_heads, head_size); - CHECK_SHAPE(dout, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size_v); + CHECK_SHAPE(dout, total_q, num_heads, head_size_v); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); } if (!is_varlen_k) { CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v); } else { CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); } @@ -1530,9 +1520,9 @@ std::vector mha_bwd( CHECK_DEVICE(dv); TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); if (!is_varlen_k) { - CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v); } else { - CHECK_SHAPE(dv, total_k, num_heads_k, head_size); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v); } } else { dv = torch::empty_like(v); @@ -1562,10 +1552,10 @@ std::vector mha_bwd( if (num_heads_k != num_heads) { // MQA / GQA if (!is_varlen) { dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat)); - dv_accum = 
torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat)); + dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, opts.dtype(at::kFloat)); } else { dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat)); - dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_v_rounded}, opts.dtype(at::kFloat)); } } @@ -1591,23 +1581,26 @@ std::vector mha_bwd( softmax_scale, window_size_left, window_size_right, + 0, // attention_chunk softcap, deterministic, sm_margin); params.total_q = total_q; params.total_k = total_k; params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); - params.dv = head_size; // We don't support hdim_v being different from hdim_qk for now + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); // params.tile_count_semaphore = tile_count_semaphore.data_ptr(); // Will be zero'ed out in the backward preprocess kernel at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32)); params.dq_semaphore = dq_semaphore.data_ptr(); + at::Tensor dk_semaphore, dv_semaphore; if (num_heads_k != num_heads && params.deterministic) { - // TODO: do we need to zero them out? - at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); - at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel + dk_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + dv_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); params.dk_semaphore = dk_semaphore.data_ptr(); params.dv_semaphore = dv_semaphore.data_ptr(); } @@ -1632,12 +1625,12 @@ std::vector mha_bwd( softmax_d.zero_(); } - return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum }; + return { softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum }; } -std::vector -mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x seqlen x num_heads x head_size - const at::Tensor &lse_partial, // num_splits x batch_size x seqlen x num_heads +std::tuple +mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size + at::Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads std::optional out_, // batch_size x seqlen x num_heads x head_size std::optional out_dtype_ ) { @@ -1738,16 +1731,104 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x return {out, softmax_lse}; } -#ifndef FLASHATTENTION_DISABLE_PYBIND - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashAttention"; - m.def("fwd", &mha_fwd, "Forward pass"); - m.def("bwd", &mha_bwd, "Backward pass"); - m.def("fwd_combine", &mha_combine, "Combine partial attention outputs"); - m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass"); +TORCH_LIBRARY(flash_attn_3, m) { + 
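The strings registered below are dispatcher schemas: "float? softmax_scale = None" corresponds to the std::optional<double> parameter on the C++ side, schema "int" maps to int64_t, and "Tensor(out!)?" marks an optional tensor the op may write into; once the library is loaded the ops are reachable from Python as torch.ops.flash_attn_3.fwd / bwd / fwd_combine / get_scheduler_metadata. A minimal sketch (hypothetical helper, not part of this patch) that looks one of them up through the dispatcher:
// Sketch only, assuming the flash_attn_3 library has already been loaded.
#include <ATen/core/dispatch/Dispatcher.h>
#include <iostream>
void print_fwd_schema() {
  auto op = c10::Dispatcher::singleton().findSchemaOrThrow("flash_attn_3::fwd", "");
  std::cout << op.schema() << std::endl;  // prints the "fwd(...)" schema string registered below
}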
m.def("fwd(" + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor(k_new!)? k_new = None," + "Tensor(v_new!)? v_new = None," + "Tensor? q_v = None," + "Tensor(out!)? out = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "Tensor? page_table = None," + "Tensor? kv_batch_idx = None," + "Tensor? leftpad_k = None," + "Tensor? rotary_cos = None," + "Tensor? rotary_sin = None," + "Tensor? seqlens_rotary = None," + "Tensor? q_descale = None," + "Tensor? k_descale = None," + "Tensor? v_descale = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "float softcap = 0.0," + "bool is_rotary_interleaved = False," + "Tensor? scheduler_metadata = None," + "int num_splits = 0," + "bool? pack_gqa = None," + "int sm_margin = 0," + "Tensor? s_aux = None," + "int cp_world_size = 1," + "int cp_rank = 0," + "Tensor? cp_tot_seqused_k = None) -> (Tensor(out!), Tensor, Tensor, Tensor)"); + m.def("bwd(" + "Tensor dout," + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor out," + "Tensor softmax_lse," + "Tensor(dq!)? dq = None," + "Tensor(dk!)? dk = None," + "Tensor(dv!)? dv = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "float softcap = 0.0," + "bool deterministic = False," + "int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("fwd_combine(" + "Tensor out_partial," + "Tensor lse_partial," + "Tensor(out!)? out = None," + "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)"); + m.def("get_scheduler_metadata(" + "int batch_size," + "int max_seqlen_q," + "int max_seqlen_k," + "int num_heads," + "int num_heads_k," + "int headdim," + "int headdim_v," + "ScalarType qkv_dtype," + "Tensor seqused_k," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? leftpad_k = None," + "int? page_size = None," + "int max_seqlen_k_new = 0," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "bool has_softcap = False," + "int num_splits = 0," + "bool? pack_gqa = None," + "int sm_margin = 0) -> Tensor"); } -#endif +TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) { + m.impl("fwd", &mha_fwd); + m.impl("bwd", &mha_bwd); + m.impl("fwd_combine", &mha_combine); + m.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata); +} diff --git a/hopper/flash_api_stable.cpp b/hopper/flash_api_stable.cpp new file mode 100644 index 00000000000..9759af86e08 --- /dev/null +++ b/hopper/flash_api_stable.cpp @@ -0,0 +1,1988 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#include + +#include + +#include "flash.h" +#include "static_switch.h" +#include "tile_size.h" +#include "heuristics.h" +#include "cuda_check.h" + +#include +#include +#include +#include +#include + +// Declare the CUDA stream function that's behind #ifdef USE_CUDA in shim.h +extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); + +#include +#include + +#include +#include +#include +#include + +using torch::stable::Tensor; +namespace tsa = torch::stable::accelerator; + +namespace { +inline tsa::DeviceGuard make_device_guard(const Tensor& t) { + return tsa::DeviceGuard(static_cast(t.get_device())); +} +std::deque device_flags; +std::vector device_properties; + +void initVectors() { + static bool init_flag [[maybe_unused]] = []() { + int device_count; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " + + std::string(cudaGetErrorString(err))); + } + device_flags.resize(device_count); + device_properties.resize(device_count); + return true; + }(); +} + +void initDeviceProperty(int device_index) { + cudaDeviceProp device_prop{}; + cudaError_t err = cudaGetDeviceProperties(&device_prop, device_index); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " + + std::string(cudaGetErrorString(err))); + } + device_properties[device_index] = device_prop; +} + +// Helper function to get device properties using raw CUDA APIs +cudaDeviceProp* get_device_prop() { + initVectors(); + int device_index; + cudaError_t err = cudaGetDevice(&device_index); + if (err != cudaSuccess) { + STD_TORCH_CHECK(false, "cudaGetDevice failed: " + + std::string(cudaGetErrorString(err))); + } + + std::call_once(device_flags[device_index], initDeviceProperty, device_index); + return &device_properties[device_index]; +} +} // anonymous namespace + + +extern "C" { +/* Creates a dummy empty _C module that can be imported from Python. + The import from Python will load the .so consisting of this file + in this extension, so that the STABLE_TORCH_LIBRARY static initializers + below are run. */ +PyObject* PyInit__C(void) +{ + static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "_C", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + NULL, /* methods */ + }; + return PyModule_Create(&module_def); +} +} + +#define CHECK_DEVICE(x) STD_TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) 
\ + do { \ + auto expected_dims = std::vector{__VA_ARGS__}; \ + STD_TORCH_CHECK(x.dim() == static_cast(expected_dims.size()), #x " must have " + std::to_string(expected_dims.size()) + " dimensions, got " + std::to_string(x.dim())); \ + for (size_t i = 0; i < expected_dims.size(); ++i) { \ + STD_TORCH_CHECK(x.size(i) == expected_dims[i], #x " dimension " + std::to_string(i) + " must have size " + std::to_string(expected_dims[i]) + ", got " + std::to_string(x.size(i))); \ + } \ + } while (0) +#define CHECK_CONTIGUOUS(x) STD_TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992 + +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const Tensor q, + const Tensor k, + const Tensor v, + Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + int attention_chunk, + const float softcap=0.f, + const int sm_margin=0) { + + // Reset the parameters + params = {}; + + params.is_bf16 = q.scalar_type() == torch::headeronly::ScalarType::BFloat16; + params.is_e4m3 = q.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.v_dim_stride = v.stride(-1); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0); + params.o_batch_stride = out.stride(0); + } + if (cu_seqlens_k_d == nullptr) { + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_q = static_cast(seqused_q); + params.seqused_k = static_cast(seqused_k); + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.softcap = softcap; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. 
+ // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + STD_TORCH_CHECK(p_dropout < 1.f); + #ifdef FLASHATTENTION_DISABLE_DROPOUT + STD_TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + #endif + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; + + // TODO: check this + if (window_size_left < 0) { window_size_left = seqlen_k - 1; } + if (window_size_right < 0) { window_size_right = seqlen_q - 1; } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; + + auto dprops = get_device_prop(); + params.arch = dprops->major * 10 + dprops->minor; + params.num_sm = dprops->multiProcessorCount - sm_margin; + + #ifdef FLASHATTENTION_DISABLE_LOCAL + STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif +} + +void set_params_dgrad(Flash_bwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const Tensor q, + const Tensor k, + const Tensor v, + const Tensor out, + const Tensor dout, + Tensor dq, + Tensor dk, + Tensor dv, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, + void *dq_accum_d, + void *dk_accum_d, + void *dv_accum_d, + void *softmax_lse_d, + void *dsoftmax_sum_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + int attention_chunk, + const float softcap=0.f, + bool deterministic=false, + int const sm_margin=0) { + + set_params_fprop(params, + b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + q, k, v, out, + cu_seqlens_q_d, + cu_seqlens_k_d, + seqused_q, + seqused_k, + softmax_lse_d, + p_dropout, + softmax_scale, + window_size_left, + window_size_right, + attention_chunk, + softcap, + sm_margin); + + // Set the pointers and strides. 
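As a quick reference for the stride bookkeeping that follows (illustrative shapes only, not taken from this diff), all recorded strides are element counts, not bytes:
// For a contiguous dout of shape (b, s_q, h, dv) = (2, 4, 8, 64):
//   dout.stride(-1) == 1             contiguous last dim (checked by the callers)
//   dout.stride(-2) == 64            head stride = dv
//   dout.stride(-3) == 8 * 64        row stride  = h * dv
//   dout.stride(0)  == 4 * 8 * 64    batch stride, only recorded when cu_seqlens_q_d == nullptr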
+ params.do_ptr = dout.data_ptr(); + params.do_row_stride = dout.stride(-3); + params.do_head_stride = dout.stride(-2); + params.dq_ptr = dq.data_ptr(); + params.dk_ptr = dk.data_ptr(); + params.dv_ptr = dv.data_ptr(); + params.dq_row_stride = dq.stride(-3); + params.dk_row_stride = dk.stride(-3); + params.dv_row_stride = dv.stride(-3); + params.dq_head_stride = dq.stride(-2); + params.dk_head_stride = dk.stride(-2); + params.dv_head_stride = dv.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.do_batch_stride = dout.stride(0); + params.dq_batch_stride = dq.stride(0); + params.dk_batch_stride = dk.stride(0); + params.dv_batch_stride = dv.stride(0); + } + + params.dq_accum_ptr = dq_accum_d; + params.dk_accum_ptr = dk_accum_d; + params.dv_accum_ptr = dv_accum_d; + + // Softmax sum + params.dsoftmax_sum = dsoftmax_sum_d; + + params.deterministic = deterministic; +} + +template +void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { + if (!params.is_e4m3) { + if (params.is_bf16) { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 + if constexpr (Arch == 90) { + if (params.dv > 256) { + return run_mha_fwd_(params, stream); + } else if (params.dv > 64) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + } else { + #ifndef FLASHATTENTION_DISABLE_FP16 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 + if constexpr (Arch == 90) { + if (params.dv > 256) { + return run_mha_fwd_(params, stream); + } else if (params.dv > 64) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + #else + STD_TORCH_CHECK(false, "This flash attention build does not support FP16."); + #endif + } + } else { + #ifndef FLASHATTENTION_DISABLE_FP8 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if 
(params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } + } + #endif + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #else + STD_TORCH_CHECK(false, "This flash attention build does not support FP8."); + #endif + } +} + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + // HEADDIM_SWITCH(params.d, [&] { + // run_mha_fwd_(params, stream); + // }); + STD_TORCH_CHECK(params.num_splits >= 1); + ARCH_SWITCH(params.arch, Arch, [&] { + SPLIT_SWITCH(params.num_splits > 1, Split, [&] { + PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] { + PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] { + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation + static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split; + SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] { + run_mha_fwd_constexpr(params, stream); + }); + }); + }); + }); + }); +} + +void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl=false) { + #ifndef FLASHATTENTION_DISABLE_SPLIT + // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively + // so that kBlockM is smaller and we have more parallelism. + if (params.is_fp32) { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } else if (params.is_bf16) { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } else { + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); + } else { + run_mha_fwd_combine_(params, stream, enable_pdl); + } + } + #else + STD_TORCH_CHECK(false, "This flash attention build does not support combine kernels."); + #endif +} + +inline bool get_pagedkv_tma(Flash_fwd_params const& params) { + if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90); + // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower, + // at least for MLA. + return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM; +} + +inline bool get_pack_gqa(Flash_fwd_params const& params) { + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. + // Has little effect on speed. 
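To make the packing trade-off concrete, a small illustration with made-up numbers (not taken from this diff): PackGQA folds the query-heads-per-KV-head ratio into the row dimension, so a single kBlockM tile covers many more query rows during decode.
// Illustrative only: h = 32 query heads, h_k = 4 KV heads, seqlen_q = 1 (decode step), kBlockM = 64
int const qhead_per_khead = 32 / 4;                // 8
int const seqlen_q_packgqa = 1 * qhead_per_khead;  // 8 rows per KV head instead of 1
// Without packing, a 64-row block would contain a single query row; with packing it holds 8,
// which is why PackGQA is always enabled below for Sm8x, non-TMA paged KV, and Split.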
+ if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } + #ifdef FLASHATTENTION_DISABLE_PACKGQA + return false; + #else + // params.page_table must already be set + if (params.h == params.h_k) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); + #endif +} + +inline int get_num_splits(Flash_fwd_params const& params) { + #ifdef FLASHATTENTION_DISABLE_SPLIT + return 1; + #else + // Always enable PackGQA for Split + // params.page_table must already be set + // This needs to match the kernel configs + bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits + // has not been set here. It's OK though because we might just underestimate kBlockN a bit + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); + // If is_local, we're not going to load all of seqlen_k + int const seqlen_k_loaded = !params.is_local + ? params.seqlen_k + : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); + int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; + int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; + int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); + // Always enable PackGQA for Split + // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. + // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending + // that batch = 1. + int total_mblocks = (params.num_splits_dynamic_ptr ? 
1 : params.b) * params.h_k * num_m_blocks; + return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); + #endif +} + +inline int get_max_headdim() { + #ifndef FLASHATTENTION_DISABLE_HDIM256 + return 256; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + return 192; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + return 128; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + return 96; + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM64 + return 64; + #endif + return 0; +} + +inline int round_up_headdim(int head_size) { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (head_size <= 64) { return 64; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (head_size <= 96) { return 96; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (head_size <= 128) { return 128; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (head_size <= 192) { return 192; } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (head_size <= 256) { return 256; } + #endif + return 256; +} + +inline int round_up_headdimv(int head_size) { + if (head_size <= 64) { return 64; } + if (head_size <= 96) { return 96; } + if (head_size <= 128) { return 128; } + if (head_size <= 192) { return 192; } + if (head_size <= 256) { return 256; } + return 512; +} + +// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available +Tensor +mha_fwd_get_scheduler_metadata( + int64_t batch_size, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + int64_t num_heads, + int64_t num_heads_k, + int64_t headdim, + int64_t headdim_v, + torch::headeronly::ScalarType qkv_dtype, + Tensor seqused_k, // b + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional leftpad_k_, // b + std::optional page_size, + int64_t max_seqlen_k_new, // 0 means we're not appending new KV + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + bool has_softcap, + int64_t num_splits, + std::optional pack_gqa_, + int64_t sm_margin) { + + STD_TORCH_CHECK(qkv_dtype == torch::headeronly::ScalarType::Half || qkv_dtype == torch::headeronly::ScalarType::BFloat16 || qkv_dtype == torch::headeronly::ScalarType::Float8_e4m3fn, + "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // Reset the parameters + Flash_fwd_params params{}; + params.is_bf16 = qkv_dtype == torch::headeronly::ScalarType::BFloat16; + params.is_e4m3 = qkv_dtype == torch::headeronly::ScalarType::Float8_e4m3fn; + params.b = batch_size; + params.seqlen_q = max_seqlen_q; + params.seqlen_k = max_seqlen_k; + params.h = num_heads; + params.h_k = num_heads_k; + params.d = headdim; + params.dv = headdim_v; + params.d_rounded = round_up_headdim(headdim); + params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v); + params.seqlen_knew = max_seqlen_k_new; + + bool const is_varlen_q = cu_seqlens_q_.has_value(); + params.cu_seqlens_q = is_varlen_q ? static_cast(cu_seqlens_q_.value().data_ptr()) : nullptr; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + params.cu_seqlens_k = is_varlen_k ? 
static_cast(cu_seqlens_k_.value().data_ptr()) : nullptr; + params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? static_cast(cu_seqlens_k_new_.value().data_ptr()): nullptr; + params.seqused_q = seqused_q_.has_value() ? static_cast(seqused_q_.value().data_ptr()) : nullptr; + params.seqused_k = static_cast(seqused_k.data_ptr()); + params.leftpad_k = leftpad_k_.has_value() ? static_cast(leftpad_k_.value().data_ptr()) : nullptr; + params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast(1) : nullptr; + if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) { + is_causal = false; + } + } + if (is_causal) { window_size_right = 0; } + + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; + if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; } + if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; + auto dprops = get_device_prop(); + params.arch = dprops->major * 10 + dprops->minor; + params.num_sm = dprops->multiProcessorCount - sm_margin; + params.softcap = has_softcap ? 1.0f : 0.0f; + + params.page_size = page_size.has_value() ? page_size.value() : 1; + params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); + + bool const use_prepare_varlen = true; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + bool is_varlen = true; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + auto device_guard = make_device_guard(seqused_k); + + // This needs to be set after get_num_splits + Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 
2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); + tile_count_semaphore = torch::stable::new_empty( + seqused_k, + {int(scheduler_needs_semaphore) + tile_count_semaphore_offset}, + std::make_optional(torch::headeronly::ScalarType::Int)); + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? static_cast(tile_count_semaphore.data_ptr()) : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? static_cast(tile_count_semaphore.data_ptr()) + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? static_cast(tile_count_semaphore.data_ptr()) + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? static_cast(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? static_cast(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr; + if (scheduler_needs_semaphore) { + if (!use_prepare_varlen) { torch::stable::zero_(tile_count_semaphore); } // If varlen we'll manually do the zero-ing + params.tile_count_semaphore = static_cast(tile_count_semaphore.data_ptr()) + tile_count_semaphore_offset; + } else { + params.tile_count_semaphore = nullptr; + } + } + + if (use_prepare_varlen) { + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); + prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + return tile_count_semaphore; +} + +// b: batch_size +// b_k: batch_size_k +// s_q: seqlen_q +// s_k: seqlen_k +// s_k_new: seqlen_k_new +// h: num_heads +// h_k: num_heads_k +// d: head_size +std::tuple +mha_fwd(Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. 
+ std::optional k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, + // TODO: check if we need max_seqlen_k + std::optional max_seqlen_k_, + std::optional page_table_, // (b_k, max_num_pages_per_seq) + std::optional kv_batch_idx_, // b. indices to index into the KV cache + std::optional leftpad_k_, // b + std::optional rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional seqlens_rotary_, // b + std::optional q_descale_, // (b, h_k), not (b, h) + std::optional k_descale_, // (b, h_k) + std::optional v_descale_, // (b, h_k) + std::optional softmax_scale_, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + double softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional scheduler_metadata_, // (b + 1) + int64_t num_splits, + std::optional pack_gqa_, + int64_t sm_margin + ) { + + auto dprops = get_device_prop(); + bool is_sm8x = dprops->major >= 8; + STD_TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_type = q.scalar_type(); + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16 || q_type == torch::headeronly::ScalarType::Float8_e4m3fn, + "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); + if (dprops->major < 9) { + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, + "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type"); + } + STD_TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + STD_TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + Tensor page_table; + const bool paged_KV = page_table_.has_value(); + if (paged_KV) { + page_table = page_table_.value(); + CHECK_DEVICE(page_table); + STD_TORCH_CHECK(page_table.scalar_type() == torch::headeronly::ScalarType::Int, "page_table must have dtype torch.int32"); + STD_TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); + } + + Tensor cu_seqlens_q; + bool const is_varlen_q = cu_seqlens_q_.has_value(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.value(); + CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); + STD_TORCH_CHECK(cu_seqlens_q.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_q must have dtype torch.int32"); + 
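A concrete (made-up) illustration of the varlen convention being validated here, and of why max_seqlen_q has to be passed in separately:
// Three sequences of lengths {5, 2, 7}:
//   cu_seqlens_q = {0, 5, 7, 14}    // batch_size + 1 = 4 int32 entries, on device
//   q has shape (total_q, h, d) = (14, h, d)
//   max_seqlen_q must be supplied explicitly (>= 7 here); it cannot be recovered from q's shape.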
STD_TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + } + Tensor cu_seqlens_k; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.value(); + CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); + STD_TORCH_CHECK(cu_seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k must have dtype torch.int32"); + STD_TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + STD_TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); + STD_TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported"); + } + + const int batch_size = !is_varlen_q ? q.size(0) : cu_seqlens_q.size(0) - 1; + int seqlen_q = !is_varlen_q ? q.size(1) : max_seqlen_q_.value(); + int total_q = !is_varlen_q ? batch_size * q.size(1) : q.size(0); + int num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_v = v.size(-1); + int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); + int const num_pages = !paged_KV ? 0 : k.size(0); + int const page_size = !paged_KV ? 1 : k.size(1); + int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value(); + int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); + int const num_heads_k = k.size(-2); + int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } + if (!kv_batch_idx_.has_value()) { + STD_TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); + } + int const max_headdim = get_max_headdim(); + STD_TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (head_size_v != head_size) { + STD_TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || + (head_size <= 64 && head_size_v <= 512), + "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], " + "or (Q/K <= 64 and V <= 512)."); + STD_TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); + if (head_size_v > 256) { + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, + "HeaddimV > 256 requires fp16 and bf16 data type"); + } + } + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + // TODO: check this + if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((head_size <= 64 || head_size > 128) || !paged_KV) { + is_causal = false; + } + } + if (is_causal) { window_size_right = 0; } + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + } else { + 
CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!paged_KV) { + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + } else { + CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); + CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); + } + + if (seqused_q_.has_value()){ + auto seqused_q = seqused_q_.value(); + STD_TORCH_CHECK(seqused_q.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_q must have dtype int32"); + CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.has_value()) { + auto seqused_k = seqused_k_.value(); + STD_TORCH_CHECK(seqused_k.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_k must have dtype int32"); + CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + STD_TORCH_CHECK(leftpad_k.scalar_type() == torch::headeronly::ScalarType::Int, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); + #ifdef FLASHATTENTION_DISABLE_VARLEN + STD_TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + + int const alignment = q_type == torch::headeronly::ScalarType::Float8_e4m3fn ? 16 : 8; + STD_TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + STD_TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); + + auto out_type = q_type == torch::headeronly::ScalarType::Float8_e4m3fn ? torch::headeronly::ScalarType::BFloat16 : q_type; + Tensor out; + if (out_.has_value()) { + out = out_.value(); + STD_TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16"); + CHECK_DEVICE(out); + STD_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(out, total_q, num_heads, head_size_v); + } + } else { + out = !is_varlen_q + ? torch::stable::new_empty(q, {batch_size, seqlen_q, num_heads, head_size_v}, std::make_optional(out_type)) + : torch::stable::new_empty(q, {total_q, num_heads, head_size_v}, std::make_optional(out_type)); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = head_size_v == head_size ? 
head_size_rounded : round_up_headdimv(head_size_v); + int const seqlen_q_rounded = round_multiple(seqlen_q, 128); + int const seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + auto device_guard = make_device_guard(q); + + Tensor softmax_lse; + if (!is_varlen_q) { + softmax_lse = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q}, std::make_optional(torch::headeronly::ScalarType::Float)); + } else { + softmax_lse = torch::stable::new_empty(q, {num_heads, total_q}, std::make_optional(torch::headeronly::ScalarType::Float)); + } + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), + !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), + seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, + seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + attention_chunk, + softcap, + sm_margin); + params.total_q = total_q; + params.total_k = total_k; + params.b_k = batch_size_k; + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; + if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma + params.leftpad_k = static_cast(leftpad_k_.value().data_ptr()); + } + if (paged_KV) { + params.page_table = static_cast(page_table.data_ptr()); + params.page_table_batch_stride = page_table.stride(0); + } + params.page_size = page_size; + params.num_pages = num_pages; + + if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma + Tensor k_new, v_new; + STD_TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); + STD_TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); + STD_TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); + Tensor cu_seqlens_k_new; + bool const is_varlen_k_new = cu_seqlens_k_new_.has_value(); + if (is_varlen_k_new) { + cu_seqlens_k_new = cu_seqlens_k_new_.value(); + CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new); + STD_TORCH_CHECK(cu_seqlens_k_new.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k_new must have dtype torch.int32"); + } + k_new = k_new_.value(); + v_new = v_new_.value(); + STD_TORCH_CHECK(k_new.scalar_type() == q_type, "k_new must have the same dtype as query"); + STD_TORCH_CHECK(v_new.scalar_type() == q_type, "v_new must have the same dtype as query"); + CHECK_DEVICE(k_new); CHECK_DEVICE(v_new); + STD_TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); + // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new + int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0; + int total_k_new = !is_varlen_k_new ? 
batch_size * k_new.size(1): k_new.size(0); + if (!is_varlen_k_new) { + CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); + } + params.seqlen_knew = seqlen_k_new; + params.total_knew = total_k_new; + params.knew_ptr = k_new.data_ptr(); + params.vnew_ptr = v_new.data_ptr(); + // All stride are in elements, not bytes. + params.knew_row_stride = k_new.stride(-3); + params.vnew_row_stride = v_new.stride(-3); + params.knew_head_stride = k_new.stride(-2); + params.vnew_head_stride = v_new.stride(-2); + if (!is_varlen_k_new) { + params.knew_batch_stride = k_new.stride(0); + params.vnew_batch_stride = v_new.stride(0); + } + if (is_varlen_k_new) { + params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); + } + } + + bool const use_prepare_varlen = is_varlen; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; + // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + // This needs to be set after get_num_splits + Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + // We don't use the persistent scheduler if Split and not Varlen + bool const scheduler_needs_semaphore = params.arch >= 90 + ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) + : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 
3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset; + // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size); + params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); + if (scheduler_metadata_.has_value()) { + Tensor scheduler_metadata = scheduler_metadata_.value(); + CHECK_DEVICE(scheduler_metadata); + CHECK_SHAPE(scheduler_metadata, metadata_size); + CHECK_CONTIGUOUS(scheduler_metadata); + STD_TORCH_CHECK(scheduler_metadata.scalar_type() == torch::headeronly::ScalarType::Int, "scheduler_metadata must have dtype int32"); + tile_count_semaphore = scheduler_metadata; + } else { + tile_count_semaphore = torch::stable::new_empty(q, {metadata_size}, torch::headeronly::ScalarType::Int); + } + if (scheduler_needs_semaphore && !use_prepare_varlen) { + torch::stable::zero_(tile_count_semaphore); // If varlen we'll manually do the zero-ing + } + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? static_cast(tile_count_semaphore.data_ptr()) : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? static_cast(tile_count_semaphore.data_ptr()) + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? static_cast(tile_count_semaphore.data_ptr()) + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? static_cast(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? static_cast(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr; + params.tile_count_semaphore = scheduler_needs_semaphore ? static_cast(tile_count_semaphore.data_ptr()) + tile_count_semaphore_offset : nullptr; + params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later + } + + if (q_v_.has_value()) { + STD_TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + STD_TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256."); + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + STD_TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + Tensor q_v = q_v_.value(); + STD_TORCH_CHECK(q_v.scalar_type() == q_type, "q_v must have the same dtype as query"); + CHECK_DEVICE(q_v); + STD_TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + } + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. 
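+        // Note: for both the dense (b, s_q, h, dv) layout and the packed (total_q, h, dv) layout,
+        // stride(-3) is the per-row (seqlen) stride and stride(-2) is the per-head stride, so the
+        // same assignments below cover both cases.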
+ params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + if (!is_varlen_q) { + params.qv_batch_stride = q_v.stride(0); + } + } + + if (rotary_cos_.has_value()) { + STD_TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); + auto rotary_cos = rotary_cos_.value(); + CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + STD_TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + STD_TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + if (paged_KV) { + STD_TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + } + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + STD_TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + + STD_TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + STD_TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + if (seqlens_rotary_.has_value()) { + Tensor seqlens_rotary = seqlens_rotary_.value(); + CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary); + STD_TORCH_CHECK(seqlens_rotary.scalar_type() == torch::headeronly::ScalarType::Int, "seqlens_rotary must have dtype torch.int32"); + CHECK_SHAPE(seqlens_rotary, batch_size); + params.seqlens_rotary = static_cast(seqlens_rotary.data_ptr()); + } + } else { + params.rotary_dim = 0; + } + + if (kv_batch_idx_.has_value()) { + auto kv_batch_idx = kv_batch_idx_.value(); + CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx); + STD_TORCH_CHECK(kv_batch_idx.scalar_type() == torch::headeronly::ScalarType::Int, "kv_batch_idx must have dtype int32"); + params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); + } + + Tensor out_accum, softmax_lse_accum; + auto outaccum_type = torch::headeronly::ScalarType::Float; + if (params.num_splits > 1) { + STD_TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); + if (!is_varlen_q) { + out_accum = torch::stable::new_empty(q, {params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, std::make_optional(outaccum_type)); + softmax_lse_accum = torch::stable::new_empty(q, {params.num_splits, batch_size, num_heads, seqlen_q}, std::make_optional(torch::headeronly::ScalarType::Float)); + params.oaccum_batch_stride = out_accum.stride(1); + params.lseaccum_batch_stride = softmax_lse_accum.stride(1); + } else { + out_accum = torch::stable::new_empty(q, {params.num_splits, num_heads, total_q, head_size_v}, std::make_optional(outaccum_type)); + softmax_lse_accum = torch::stable::new_empty(q, {params.num_splits, num_heads, total_q}, std::make_optional(torch::headeronly::ScalarType::Float)); + } + params.is_fp32 = false; + params.oaccum_ptr = out_accum.data_ptr(); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_split_stride = out_accum.stride(0); + params.oaccum_row_stride = out_accum.stride(-2); + params.oaccum_head_stride = out_accum.stride(-3); + 
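+            // The LSE accumulator follows the same indexing convention; both accumulators are kept
+            // in fp32 and are reduced into the final output by run_mha_fwd_combine further below.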
params.lseaccum_split_stride = softmax_lse_accum.stride(0); + params.lseaccum_head_stride = softmax_lse_accum.stride(-2); + } + + if (q_type == torch::headeronly::ScalarType::Float8_e4m3fn) { + if (q_descale_.has_value()) { + auto q_descale = q_descale_.value(); + CHECK_DEVICE(q_descale); + CHECK_SHAPE(q_descale, batch_size, num_heads_k); + params.q_descale_ptr = static_cast(q_descale.data_ptr()); + params.q_descale_batch_stride = q_descale.stride(0); + params.q_descale_head_stride = q_descale.stride(1); + } else { + params.q_descale_ptr = nullptr; + } + if (k_descale_.has_value()) { + auto k_descale = k_descale_.value(); + CHECK_DEVICE(k_descale); + CHECK_SHAPE(k_descale, batch_size, num_heads_k); + params.k_descale_ptr = static_cast(k_descale.data_ptr()); + params.k_descale_batch_stride = k_descale.stride(0); + params.k_descale_head_stride = k_descale.stride(1); + } else { + params.k_descale_ptr = nullptr; + } + if (v_descale_.has_value()) { + auto v_descale = v_descale_.value(); + CHECK_DEVICE(v_descale); + CHECK_SHAPE(v_descale, batch_size, num_heads_k); + params.v_descale_ptr = static_cast(v_descale.data_ptr()); + params.v_descale_batch_stride = v_descale.stride(0); + params.v_descale_head_stride = v_descale.stride(1); + } else { + params.v_descale_ptr = nullptr; + } + } + + #ifdef FLASHATTENTION_DISABLE_LOCAL + STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif + #ifdef FLASHATTENTION_DISABLE_SOFTCAP + STD_TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); + #endif + #ifdef FLASHATTENTION_DISABLE_SPLIT + STD_TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); + #endif + #ifdef FLASHATTENTION_DISABLE_PACKGQA + STD_TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); + #endif + #ifdef FLASHATTENTION_DISABLE_PAGEDKV + STD_TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); + #endif + #ifdef FLASHATTENTION_DISABLE_APPENDKV + STD_TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); + #endif + + if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) { + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); + run_mha_fwd(params, stream); + if (params.num_splits > 1) { + if (out_type == torch::headeronly::ScalarType::BFloat16) { + // Since we want output in BF16. Otherwise fwd_combine will output to FP16 + params.is_bf16 = true; + } + // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1 + // and seqlen = total_q, and don't need to dispatch to Varlen there. + // However, with dynamic split, each row needs to know which batch it belongs to + // to read the number of splits, so we just use the varlen version of combine kernel. 
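+            // Rough sketch of what the combine step computes per (batch, head, query row):
+            //   lse = logsumexp over splits s of lse_accum[s]
+            //   out = sum over splits s of exp(lse_accum[s] - lse) * out_accum[s]
+            // i.e. the partial outputs are renormalized by their partial log-sum-exps and summed.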
+ // if (is_varlen_q && !seqused_q_.has_value()) { + // if (is_varlen_q) { + // params.b = 1; + // params.seqlen_q = total_q; + // } + // This will zero out the semaphore if needed + run_mha_fwd_combine(params, stream, true /*enable_pdl*/); + } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { + // need to zero out the semaphore in this case + auto slice = torch::stable::narrow(tile_count_semaphore, 0, params.tile_count_semaphore_offset, 1); + torch::stable::zero_(slice); + } + } else if (total_q > 0 && num_heads_k > 0) { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + torch::stable::zero_(out); + torch::stable::fill_(softmax_lse, std::numeric_limits::infinity()); + } + + // return {out, softmax_lse}; + return {out, softmax_lse, out_accum, softmax_lse_accum}; +} + +#ifdef FLASHATTENTION_DISABLE_BACKWARD +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + STD_TORCH_CHECK(false, "Flash-Attention was built with backward disabled"); +} +#else +template +void run_mha_bwd_constexpr(Flash_bwd_params ¶ms, cudaStream_t stream) { + if (!params.is_bf16) { + #ifndef FLASHATTENTION_DISABLE_FP16 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d_rounded == 64) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d_rounded == 96) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d_rounded == 128) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d_rounded == 192) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } + #endif + #else + STD_TORCH_CHECK(false, "This flash attention build does not support FP16."); + #endif + } else { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d_rounded == 64) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d_rounded == 96) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d_rounded == 128) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d_rounded == 192) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } + #endif + } +} + +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + // FP16_SWITCH(!params.is_bf16, [&] { + // HEADDIM_SWITCH(params.d, [&] { + // run_mha_bwd_(params, stream); + // }); + // }); + ARCH_SWITCH(params.arch, Arch, [&] { + SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] { + run_mha_bwd_constexpr(params, stream); + }); + }); +} +#endif + + +// b: batch_size +// s_q: seqlen_q +// s_k: seqlen_k +// h: num_heads +// h_k: num_heads_k +// d: head_size +std::tuple mha_bwd( + Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q + std::optional dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + 
std::optional dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + std::optional dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, + std::optional max_seqlen_k_, + std::optional softmax_scale_, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + double softcap, + bool deterministic, + int64_t sm_margin +) { + + #ifdef FLASHATTENTION_DISABLE_BACKWARD + STD_TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif + + auto dprops = get_device_prop(); + bool is_sm8x = dprops->major >= 8; + STD_TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_type = q.scalar_type(); + STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, + "FlashAttention only support fp16 and bf16 data type"); + STD_TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + STD_TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + STD_TORCH_CHECK(out.scalar_type() == q_type, "query and out must have the same dtype"); + STD_TORCH_CHECK(dout.scalar_type() == q_type, "query and dout must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + STD_TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + Tensor cu_seqlens_q; + bool const is_varlen_q = cu_seqlens_q_.has_value(); + if (is_varlen_q) { + cu_seqlens_q = cu_seqlens_q_.value(); + CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); + STD_TORCH_CHECK(cu_seqlens_q.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_q must have dtype torch.int32"); + STD_TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); + } + Tensor cu_seqlens_k; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + if (is_varlen_k) { + cu_seqlens_k = cu_seqlens_k_.value(); + CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); + STD_TORCH_CHECK(cu_seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k must have dtype torch.int32"); + STD_TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); + } + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value(); + #ifdef FLASHATTENTION_DISABLE_VARLEN + STD_TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + + // auto const sizes = q.sizes(); + int const batch_size = !is_varlen_q ? q.size(0) : cu_seqlens_q.size(0) - 1; + int const seqlen_q = !is_varlen_q ? q.size(1) : max_seqlen_q_.value(); + int const total_q = !is_varlen_q ? 
batch_size * q.size(1) : q.size(0); + int const num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_v = v.size(-1); + int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value(); + int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); + int const num_heads_k = k.size(-2); + STD_TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + STD_TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8"); + int const max_headdim = get_max_headdim(); + STD_TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + if (is_causal) { window_size_right = 0; } + // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true. + // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA). + is_causal = window_size_left < 0 && window_size_right == 0; + + int const arch = dprops->major * 10 + dprops->minor; + int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v)); + int const head_size_v_rounded = head_size_rounded; + STD_TORCH_CHECK(!deterministic || head_size_rounded < 256, "Deterministic backward not supported for hdim 256."); + // Very important that these match the kernel configs + bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; + int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) + : (head_size_rounded <= 96 ? 64 + : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80) + : 64)); + int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64; + int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32; + int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80); + int const kBlockN_sm90 = head_size_rounded <= 128 + ? 128 + : (head_size_rounded <= 192 ? 96 : 80); + int const kBlockN_sm80 = head_size_rounded <= 128 + ? 128 + : (head_size_rounded <= 192 ? 80 : 64); + int const kBlockN_sm86 = head_size_rounded <= 64 ? 128 + : (head_size_rounded <= 96 ? 128 + : (head_size_rounded <= 128 ? 96 + : (head_size_rounded <= 192 ? 64 : 64))); + int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? 
kBlockN_sm86 : kBlockN_sm80); + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM); + int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN); + int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM); + int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN); + + if (!is_varlen_q) { + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size_v); + CHECK_SHAPE(dout, total_q, num_heads, head_size_v); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + } + if (!is_varlen_k) { + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + } + + if (seqused_q_.has_value()){ + auto seqused_q = seqused_q_.value(); + STD_TORCH_CHECK(seqused_q.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_q must have dtype int32"); + CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); + CHECK_SHAPE(seqused_q, batch_size); + } + if (seqused_k_.has_value()){ + auto seqused_k = seqused_k_.value(); + STD_TORCH_CHECK(seqused_k.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_k must have dtype int32"); + CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); + CHECK_SHAPE(seqused_k, batch_size); + } + + Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + STD_TORCH_CHECK(dq.scalar_type() == q_type, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + STD_TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); + } else { + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } + } else { + dq = torch::stable::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + STD_TORCH_CHECK(dk.scalar_type() == q_type, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + STD_TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + if (!is_varlen_k) { + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); + } else { + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } + } else { + dk = torch::stable::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + STD_TORCH_CHECK(dv.scalar_type() == q_type, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + STD_TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + if (!is_varlen_k) { + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v); + } else { + CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v); + } + } else { + dv = torch::stable::empty_like(v); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + auto device_guard = make_device_guard(q); + + // auto opts = q.options(); + // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 + Tensor softmax_d, softmax_lse_log2; + if (!is_varlen) { + // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes 
for TMA / LDG.64 + softmax_d = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + softmax_lse_log2 = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + } else { + softmax_d = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + softmax_lse_log2 = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + } + Tensor dq_accum, dk_accum, dv_accum; + if (!is_varlen) { + dq_accum = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + } else { + dq_accum = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + } + if (num_heads_k != num_heads) { // MQA / GQA + if (!is_varlen) { + dk_accum = torch::stable::new_empty(q, {batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + dk_accum = torch::stable::fill_(dk_accum, 0.0); + dv_accum = torch::stable::new_empty(q, {batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + dv_accum = torch::stable::fill_(dv_accum, 0.0); + } else { + dk_accum = torch::stable::new_empty(q, {num_heads_k, total_k_padded_rounded, head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + dk_accum = torch::stable::fill_(dk_accum, 0.0); + dv_accum = torch::stable::new_empty(q, {num_heads_k, total_k_padded_rounded, head_size_v_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); + dv_accum = torch::stable::fill_(dv_accum, 0.0); + } + } + + Flash_bwd_params params; + set_params_dgrad(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + dout, dq, dk, dv, + !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), + !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), + seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, + seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr, + dq_accum.data_ptr(), + num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr, + num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + 0, // attention_chunk + softcap, + deterministic, + sm_margin); + params.total_q = total_q; + params.total_k = total_k; + params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; + + // auto tile_count_semaphore = (params.is_causal || params.is_local) ? 
torch::zeros({1}, opts.dtype(torch::headeronly::ScalarType::Int)) : torch::empty({1}, opts.dtype(torch::headeronly::ScalarType::Int)); + // params.tile_count_semaphore = static_cast(tile_count_semaphore.data_ptr()); + // Will be zero'ed out in the backward preprocess kernel + Tensor dq_semaphore = torch::stable::new_empty(q, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, std::make_optional(torch::headeronly::ScalarType::Int)); + params.dq_semaphore = static_cast(dq_semaphore.data_ptr()); + Tensor dk_semaphore, dv_semaphore; + if (num_heads_k != num_heads && params.deterministic) { + // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel + dk_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int)); + dv_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int)); + params.dk_semaphore = static_cast(dk_semaphore.data_ptr()); + params.dv_semaphore = static_cast(dv_semaphore.data_ptr()); + } + + #ifdef FLASHATTENTION_DISABLE_LOCAL + STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); + #endif + #ifdef FLASHATTENTION_DISABLE_SOFTCAP + STD_TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); + #endif + + if (total_q > 0 && total_k > 0 && num_heads_k > 0) { + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); + run_mha_bwd(params, stream); + } else if (total_k > 0 && num_heads_k > 0) { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. 
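+        // With no query tokens there is nothing to backpropagate: dK and dV are sums over query
+        // rows, so they (and softmax_d) can simply be zero-filled here.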
+ torch::stable::zero_(dk); + torch::stable::zero_(dv); + torch::stable::zero_(softmax_d); + } else if (total_q > 0 && num_heads_k > 0) { + torch::stable::zero_(dq); + torch::stable::zero_(softmax_d); + } + + return { softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum }; +} + +std::tuple +mha_combine(Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size + Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads + std::optional out_, // batch_size x seqlen x num_heads x head_size + std::optional out_dtype_ + ) { + + auto dprops = get_device_prop(); + bool is_sm8x = dprops->major >= 8; + STD_TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer."); + + auto out_partial_type = out_partial.scalar_type(); + STD_TORCH_CHECK(out_partial_type == torch::headeronly::ScalarType::Float, "Attention combine function only support fp32 data type"); + STD_TORCH_CHECK(lse_partial.scalar_type() == torch::headeronly::ScalarType::Float, "Attention combine function only support fp32 data type"); + + CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial); + + STD_TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension"); + + // const auto sizes = out_partial.sizes(); + + const int num_splits = out_partial.size(0); + const int batch_size = out_partial.size(1); + const int seqlen = out_partial.size(2); + const int num_heads = out_partial.size(3); + const int head_size_og = out_partial.size(4); + STD_TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); + + CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); + CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads); + + int const alignment = 4; + Tensor out_partial_padded; + auto pad = [](Tensor x, int alignment) { + return x.size(-1) % alignment == 0 ? 
x : torch::stable::pad(x, {0, alignment - x.size(-1) % alignment}); + }; + out_partial_padded = pad(out_partial, alignment); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, alignment); + + // auto opts = out_partial.options(); + torch::headeronly::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type()); + STD_TORCH_CHECK(out_type == torch::headeronly::ScalarType::Float || out_type == torch::headeronly::ScalarType::BFloat16 || out_type == torch::headeronly::ScalarType::Half, "Output type must be FP32, FP16 or BF16"); + Tensor out; + if (out_.has_value()) { + out = out_.value(); + STD_TORCH_CHECK(out.scalar_type() == out_type); + CHECK_DEVICE(out); + STD_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og); + if (head_size_og % alignment != 0) { + out = torch::stable::new_empty(out_partial, {batch_size, seqlen, num_heads, head_size}, std::make_optional(out_type)); + } + } else { + out = torch::stable::new_empty(out_partial, {batch_size, seqlen, num_heads, head_size}, std::make_optional(out_type)); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + auto device_guard = make_device_guard(out_partial); + + auto softmax_lse = torch::stable::new_empty(out_partial, {batch_size, num_heads, seqlen}, std::make_optional(torch::headeronly::ScalarType::Float)); + softmax_lse = torch::stable::transpose(softmax_lse, 1, 2); + + Flash_fwd_params params {}; // Need to reset the params to set everything to zero + params.is_fp32 = out_type == torch::headeronly::ScalarType::Float; + params.is_bf16 = out_type == torch::headeronly::ScalarType::BFloat16; + params.oaccum_ptr = out_partial_padded.data_ptr(); + params.softmax_lseaccum_ptr = lse_partial.data_ptr(); + params.o_ptr = out.data_ptr(); + params.softmax_lse_ptr = softmax_lse.data_ptr(); + params.b = batch_size; + params.h = num_heads; + params.seqlen_q = seqlen; + params.dv = head_size; + params.num_splits = num_splits; + params.oaccum_split_stride = out_partial_padded.stride(0); + params.oaccum_row_stride = out_partial_padded.stride(2); + params.oaccum_head_stride = out_partial_padded.stride(3); + params.oaccum_batch_stride = out_partial_padded.stride(1); + params.lseaccum_split_stride = lse_partial.stride(0); + params.lseaccum_head_stride = lse_partial.stride(3); + params.lseaccum_batch_stride = lse_partial.stride(1); + params.o_row_stride = out.stride(1); + params.o_head_stride = out.stride(2); + params.o_batch_stride = out.stride(0); + params.arch = dprops->major * 10 + dprops->minor; + + if (seqlen > 0 && batch_size > 0) { + auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); + run_mha_fwd_combine(params, stream, false /*enable_pdl*/); + } + + Tensor out_padded = out; + if (head_size_og % alignment != 0) { + out = torch::stable::narrow(out, -1, 0, head_size_og); + // if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, softmax_lse}; +} + +void boxed_mha_fwd( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs +) { + auto q = to(stack[0]); + auto k = to(stack[1]); + auto v = to(stack[2]); + auto k_new = to>(stack[3]); + auto v_new = to>(stack[4]); + auto q_v = 
to>(stack[5]); + auto out = to>(stack[6]); + auto cu_seqlens_q = to>(stack[7]); + auto cu_seqlens_k = to>(stack[8]); + auto cu_seqlens_k_new = to>(stack[9]); + auto seqused_q = to>(stack[10]); + auto seqused_k = to>(stack[11]); + auto max_seqlen_q = to>(stack[12]); + auto max_seqlen_k = to>(stack[13]); + auto page_table = to>(stack[14]); + auto kv_batch_idx = to>(stack[15]); + auto leftpad_k = to>(stack[16]); + auto rotary_cos = to>(stack[17]); + auto rotary_sin = to>(stack[18]); + auto seqlens_rotary = to>(stack[19]); + auto q_descale = to>(stack[20]); + auto k_descale = to>(stack[21]); + auto v_descale = to>(stack[22]); + auto softmax_scale = to>(stack[23]); + auto is_causal = to(stack[24]); + auto window_size_left = to(stack[25]); + auto window_size_right = to(stack[26]); + auto attention_chunk = to(stack[27]); + auto softcap = to(stack[28]); + auto is_rotary_interleaved = to(stack[29]); + auto scheduler_metadata = to>(stack[30]); + auto num_splits = to(stack[31]); + auto pack_gqa = to>(stack[32]); + auto sm_margin = to(stack[33]); + + auto [out_, softmax_lse, out_accum, softmax_lse_accum] = mha_fwd(q, k, v, k_new, v_new, q_v, out, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale, softmax_scale, is_causal, window_size_left, window_size_right, attention_chunk, softcap, is_rotary_interleaved, scheduler_metadata, num_splits, pack_gqa, sm_margin); + + + stack[0] = from(out_); + stack[1] = from(softmax_lse); + stack[2] = from(out_accum); + stack[3] = from(softmax_lse_accum); +} + +void boxed_mha_bwd( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs +) { + auto dout = to(stack[0]); + auto q = to(stack[1]); + auto k = to(stack[2]); + auto v = to(stack[3]); + auto out = to(stack[4]); + auto softmax_lse = to(stack[5]); + auto dq = to>(stack[6]); + auto dk = to>(stack[7]); + auto dv = to>(stack[8]); + auto cu_seqlens_q = to>(stack[9]); + auto cu_seqlens_k = to>(stack[10]); + auto seqused_q = to>(stack[11]); + auto seqused_k = to>(stack[12]); + auto max_seqlen_q = to>(stack[13]); + auto max_seqlen_k = to>(stack[14]); + auto softmax_scale = to>(stack[15]); + auto is_causal = to(stack[16]); + auto window_size_left = to(stack[17]); + auto window_size_right = to(stack[18]); + auto softcap = to(stack[19]); + auto deterministic = to(stack[20]); + auto sm_margin = to(stack[21]); + + auto [softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum] = mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal, window_size_left, window_size_right, softcap, deterministic, sm_margin); + + stack[0] = from(softmax_d); + stack[1] = from(softmax_lse_log2); + stack[2] = from(dq_accum); + stack[3] = from(dk_accum); + stack[4] = from(dv_accum); +} + +void boxed_mha_combine( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs +) { + auto out_partial = to(stack[0]); + auto lse_partial = to(stack[1]); + auto out = to>(stack[2]); + auto out_dtype = to>(stack[3]); + + auto [out_, softmax_lse] = mha_combine(out_partial, lse_partial, out, out_dtype); + + stack[0] = from(out_); + stack[1] = from(softmax_lse); +} + +void boxed_mha_fwd_get_scheduler_metadata( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs +) { + auto batch_size = to(stack[0]); + auto max_seqlen_q = to(stack[1]); + auto max_seqlen_k = to(stack[2]); + auto 
num_heads = to(stack[3]); + auto num_heads_k = to(stack[4]); + auto headdim = to(stack[5]); + auto headdim_v = to(stack[6]); + auto qkv_dtype = to(stack[7]); + auto seqused_k = to(stack[8]); + auto cu_seqlens_q = to>(stack[9]); + auto cu_seqlens_k = to>(stack[10]); + auto cu_seqlens_k_new = to>(stack[11]); + auto seqused_q = to>(stack[12]); + auto leftpad_k = to>(stack[13]); + auto page_size = to>(stack[14]); + auto max_seqlen_k_new = to(stack[15]); + auto is_causal = to(stack[16]); + auto window_size_left = to(stack[17]); + auto window_size_right = to(stack[18]); + auto attention_chunk = to(stack[19]); + auto has_softcap = to(stack[20]); + auto num_splits = to(stack[21]); + auto pack_gqa = to>(stack[22]); + auto sm_margin = to(stack[23]); + + auto scheduler_metadata = mha_fwd_get_scheduler_metadata(batch_size, max_seqlen_q, max_seqlen_k, num_heads, num_heads_k, headdim, headdim_v, qkv_dtype, seqused_k, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, seqused_q, leftpad_k, page_size, max_seqlen_k_new, is_causal, window_size_left, window_size_right, attention_chunk, has_softcap, num_splits, pack_gqa, sm_margin); + + stack[0] = from(scheduler_metadata); +} + +STABLE_TORCH_LIBRARY(flash_attn_3, m) { + m.def("fwd(" + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor(k_new!)? k_new = None," + "Tensor(v_new!)? v_new = None," + "Tensor? q_v = None," + "Tensor(out!)? out = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "Tensor? page_table = None," + "Tensor? kv_batch_idx = None," + "Tensor? leftpad_k = None," + "Tensor? rotary_cos = None," + "Tensor? rotary_sin = None," + "Tensor? seqlens_rotary = None," + "Tensor? q_descale = None," + "Tensor? k_descale = None," + "Tensor? v_descale = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "float softcap = 0.0," + "bool is_rotary_interleaved = False," + "Tensor? scheduler_metadata = None," + "int num_splits = 0," + "bool? pack_gqa = None," + "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)"); + m.def("bwd(" + "Tensor dout," + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor out," + "Tensor softmax_lse," + "Tensor(dq!)? dq = None," + "Tensor(dk!)? dk = None," + "Tensor(dv!)? dv = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "float softcap = 0.0," + "bool deterministic = False," + "int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("fwd_combine(" + "Tensor out_partial," + "Tensor lse_partial," + "Tensor(out!)? out = None," + "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)"); + m.def("get_scheduler_metadata(" + "int batch_size," + "int max_seqlen_q," + "int max_seqlen_k," + "int num_heads," + "int num_heads_k," + "int headdim," + "int headdim_v," + "ScalarType qkv_dtype," + "Tensor seqused_k," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? leftpad_k = None," + "int? 
page_size = None," + "int max_seqlen_k_new = 0," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "bool has_softcap = False," + "int num_splits = 0," + "bool? pack_gqa = None," + "int sm_margin = 0) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) { + m.impl("fwd", &boxed_mha_fwd); + m.impl("bwd", &boxed_mha_bwd); + m.impl("fwd_combine", &boxed_mha_combine); + m.impl("get_scheduler_metadata", &boxed_mha_fwd_get_scheduler_metadata); +} diff --git a/hopper/flash_api_torch_lib.cpp b/hopper/flash_api_torch_lib.cpp index 338c9d408bf..252f9f01ded 100644 --- a/hopper/flash_api_torch_lib.cpp +++ b/hopper/flash_api_torch_lib.cpp @@ -17,73 +17,75 @@ // h: num_heads // h_k: num_heads_k // d: head_size -std::vector -mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. - std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new - std::optional &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q - std::optional &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &cu_seqlens_k_new_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. - std::optional max_seqlen_q_, +std::tuple +mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. + std::optional k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, // TODO: check if we need max_seqlen_k - std::optional max_seqlen_k_, - std::optional &page_table_, // (b_k, max_num_pages_per_seq) - std::optional &kv_batch_idx_, // b. 
indices to index into the KV cache - std::optional &leftpad_k_, // b - std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional &seqlens_rotary_, // b - std::optional &q_descale_, // (b, h_k), not (b, h) - std::optional &k_descale_, // (b, h_k) - std::optional &v_descale_, // (b, h_k) - float const softmax_scale, + std::optional max_seqlen_k_, + std::optional page_table_, // (b_k, max_num_pages_per_seq) + std::optional kv_batch_idx_, // b. indices to index into the KV cache + std::optional leftpad_k_, // b + std::optional rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional seqlens_rotary_, // b + std::optional q_descale_, // (b, h_k), not (b, h) + std::optional k_descale_, // (b, h_k) + std::optional v_descale_, // (b, h_k) + std::optional softmax_scale_, bool is_causal, - int window_size_left, - int window_size_right, - float const softcap, - bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - std::optional &scheduler_metadata_, // (b + 1) - int num_splits, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + double softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional scheduler_metadata_, // (b + 1) + int64_t num_splits, std::optional pack_gqa_, - int const sm_margin, - std::optional &s_aux_, - int const cp_world_size, - int const cp_rank, - std::optional &cp_tot_seqused_k + int64_t sm_margin, + std::optional s_aux_, + int64_t cp_world_size, + int64_t cp_rank, + std::optional cp_tot_seqused_k ); // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available at::Tensor mha_fwd_get_scheduler_metadata( - int batch_size, - int max_seqlen_q, - int max_seqlen_k, - int num_heads, - int num_heads_k, - int headdim, - int headdim_v, + int64_t batch_size, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + int64_t num_heads, + int64_t num_heads_k, + int64_t headdim, + int64_t headdim_v, at::ScalarType qkv_dtype, - const at::Tensor &seqused_k, // b - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &cu_seqlens_k_new_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &leftpad_k_, // b - std::optional page_size, - int max_seqlen_k_new, // 0 means we're not appending new KV + at::Tensor seqused_k, // b + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional leftpad_k_, // b + std::optional page_size, + int64_t max_seqlen_k_new, // 0 means we're not appending new KV bool is_causal, - int window_size_left, - int window_size_right, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, bool has_softcap, - int num_splits, + int64_t num_splits, std::optional pack_gqa_, - int const sm_margin + int64_t sm_margin ); /** @@ -113,10 +115,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor? q_descale," " Tensor? k_descale," " Tensor? v_descale," - " float softmax_scale," + " float? 
softmax_scale," " bool is_causal," " int window_size_left," " int window_size_right," + " int attention_chunk," " float softcap," " bool is_rotary_interleaved," " Tensor? scheduler_metadata," @@ -126,7 +129,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor? s_aux," " int cp_world_size," " int cp_rank," - " Tensor? cp_tot_seqused_k) -> Tensor[]"); + " Tensor? cp_tot_seqused_k) -> (Tensor, Tensor, Tensor, Tensor)"); ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd)); ops.def("get_scheduler_metadata(" @@ -149,6 +152,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " bool is_causal," " int window_size_left," " int window_size_right," + " int attention_chunk," " bool has_softcap," " int num_splits," " bool? pack_gqa," diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 6f06f96c830..0c57aca1649 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -1,58 +1,84 @@ # Copyright (c) 2023, Tri Dao. -from typing import Optional, Union +from typing import Optional, Union, List, Tuple import torch import torch.nn as nn # isort: off # We need to import the CUDA kernels after importing torch -import flash_attn_3_cuda +import flash_attn_3._C # Registers operators with PyTorch # isort: on +flash_attn_3_cuda = torch.ops.flash_attn_3 def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x +def round_multiple(x, m): + return (x + m - 1) // m * m + + +def round_up_headdim(head_size: int) -> int: + from flash_attn_config import CONFIG + + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM64"]: + if head_size <= 64: + return 64 + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM96"]: + if head_size <= 96: + return 96 + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM128"]: + if head_size <= 128: + return 128 + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM192"]: + if head_size <= 192: + return 192 + if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM256"]: + if head_size <= 256: + return 256 + return 256 + + +@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") def _flash_attn_forward( - q, - k, - v, - k_new, - v_new, - qv, - out, - cu_seqlens_q, - cu_seqlens_k, - cu_seqlens_k_new, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - page_table, - kv_batch_idx, - leftpad_k, - rotary_cos, - rotary_sin, - seqlens_rotary, - q_descale, - k_descale, - v_descale, - softmax_scale, - causal, - window_size=(-1, -1), - softcap=0.0, - rotary_interleaved=True, - scheduler_metadata=None, - num_splits=1, - pack_gqa=None, - sm_margin=0, - s_aux=None, - cp_world_size=1, - cp_rank=0, - cp_tot_seqused_k=None): + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor] = None, + v_new: Optional[torch.Tensor] = None, + qv: Optional[torch.Tensor] = None, + out_: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, + kv_batch_idx: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + seqlens_rotary: Optional[torch.Tensor] = None, + q_descale: 
Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + softcap: float = 0.0, + rotary_interleaved: bool = True, + scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ @@ -64,14 +90,14 @@ def _flash_attn_forward( ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] seqlens_rotary = maybe_contiguous(seqlens_rotary) - out, softmax_lse, *rest = flash_attn_3_cuda.fwd( + out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_cuda.fwd( q, k, v, k_new, v_new, qv, - out, + out_, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, @@ -90,8 +116,9 @@ def _flash_attn_forward( v_descale, softmax_scale, causal, - window_size[0], - window_size[1], + window_size_left, + window_size_right, + attention_chunk, softcap, rotary_interleaved, scheduler_metadata, @@ -103,59 +130,314 @@ def _flash_attn_forward( cp_rank, cp_tot_seqused_k, ) - return out, softmax_lse, *rest + if out_accum is None: + out_accum = torch.tensor([], device=out.device) + + if softmax_lse_accum is None: + softmax_lse_accum = torch.tensor([], device=out.device) + + return out, softmax_lse, out_accum, softmax_lse_accum + + +@torch.library.register_fake("flash_attn_3::_flash_attn_forward") +def _flash_attn_forward_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor] = None, + v_new: Optional[torch.Tensor] = None, + qv: Optional[torch.Tensor] = None, + out_: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + page_table: Optional[torch.Tensor] = None, + kv_batch_idx: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + seqlens_rotary: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + softcap: float = 0.0, + rotary_interleaved: bool = True, + scheduler_metadata: Optional[torch.Tensor] = None, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Symbolic fake implementation of flash attention forward. + Returns tensors with the correct shapes and dtypes without actual computation. 
+ """ + + # Determine if we're in varlen mode + is_varlen_q = cu_seqlens_q is not None + # Get dimensions from query tensor + if is_varlen_q: + # varlen mode: q is (total_q, num_heads, head_size) + total_q, num_heads, head_size = q.shape + batch_size = cu_seqlens_q.shape[0] - 1 + + if max_seqlen_q is None: + raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided") + seqlen_q = max_seqlen_q + else: + # batch mode: q is (batch_size, seqlen_q, num_heads, head_size) + batch_size, seqlen_q, num_heads, head_size = q.shape + total_q = batch_size * q.shape[1] + # Get value head dimension + head_size_v = v.shape[-1] + + # Determine output dtype (FP8 inputs produce BF16 outputs) + q_type = q.dtype + if q_type == torch.float8_e4m3fn: + out_dtype = torch.bfloat16 + else: + out_dtype = q_type + + # Create output tensor + if out_ is not None: + # If out_ is provided, _flash_attn_forward becomes non-functional + raise TypeError("Tracing (torch.compile/torch.export) with pre-allocated output tensor is not supported.") + + if is_varlen_q: + out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) + else: + out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) + + # Create softmax_lse tensor + if is_varlen_q: + softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device) + else: + softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device) + + # TODO(guilhermeleobas): Implement "get_num_splits" + # There's an heuristic to compute num_splits when "num_splits <= 0" + # assert that num_splits is > 0 for now + if num_splits <= 0: + raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. 
Got {num_splits=}") + + if num_splits > 1: + if is_varlen_q: + out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device) + softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device) + else: + out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device) + softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device) + else: + # Tensors are not set when num_splits < 1 + out_accum = torch.tensor([], device=out.device) + softmax_lse_accum = torch.tensor([], device=out.device) + + return out, softmax_lse, out_accum, softmax_lse_accum + + +@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") def _flash_attn_backward( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + sequed_q: Optional[torch.Tensor] = None, + sequed_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + is_causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + softcap: float = 0.0, + deterministic: bool = False, + sm_margin: int = 0, +) -> torch.Tensor: + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + softmax_d, *rest = flash_attn_3_cuda.bwd( dout, q, k, v, out, softmax_lse, + dq, + dk, + dv, cu_seqlens_q, cu_seqlens_k, sequed_q, sequed_k, max_seqlen_q, max_seqlen_k, - dq, - dk, - dv, softmax_scale, - causal, - window_size=(-1, -1), - softcap=0.0, - deterministic=False, - sm_margin=0, -): - # dq, dk, dv are allocated by us so they should already be contiguous - dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd( + is_causal, + window_size_left, + window_size_right, + softcap, + deterministic, + sm_margin, + ) + return softmax_d + + +@torch.library.register_fake("flash_attn_3::_flash_attn_backward") +def _flash_attn_backward_fake( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + sequed_q: Optional[torch.Tensor] = None, + sequed_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + is_causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + softcap: float = 0.0, + deterministic: bool = False, + sm_margin: int = 0, +) -> torch.Tensor: + + is_varlen_q = cu_seqlens_q is not None + is_varlen_k = cu_seqlens_q is not None + is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None + + if not is_varlen_q: + batch_size = q.size(0) + seqlen_q = q.size(1) + seqlen_k = k.size(1) + total_q = batch_size * q.size(1) + else: + batch_size = cu_seqlens_q.size(0) - 
1 + total_q = q.size(0) + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if window_size_left >= seqlen_k - 1: + window_size_left = -1 + + if window_size_right >= seqlen_q - 1: + window_size_right = -1 + + if is_causal: + window_size_right = 0 + + is_causal = window_size_left < 0 and window_size_right == 0 + + head_size = q.size(-1) + head_size_v = v.size(-1) + head_size_rounded = round_up_headdim(max(head_size, head_size_v)) + + # Hopper GPUs use CUDA compute capability 9.0 + cap = torch.cuda.get_device_capability(q.device) + arch = cap[0] * 10 + cap[1] + + is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal + + if head_size_rounded <= 64: + kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128 + elif head_size_rounded <= 96: + kBlockM_sm90 = 64 + elif head_size_rounded <= 128: + kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80 + else: + kBlockM_sm90 = 64 + + kBlockM_sm80 = 128 if head_size_rounded <= 64 else 64 + kBlockM_sm86 = 64 if head_size_rounded <= 192 else 32 + + if arch >= 90: + kBlockM = kBlockM_sm90 + elif arch == 86 or arch == 89: + kBlockM = kBlockM_sm86 + else: + kBlockM = kBlockM_sm80 + + num_heads = q.shape[-2] + seqlen_q_rounded = round_multiple(seqlen_q, kBlockM) + + total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM) + + dq = torch.empty_like(q) if dq is None else dq + dk = torch.empty_like(k) if dk is None else dk + dv = torch.empty_like(v) if dv is None else dv + + if not is_varlen: + softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device) + else: + softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device) + + return softmax_d + + +def setup_context(ctx, inputs, output): + q, k, v = inputs[:3] + out, softmax_lse, _, _ = output + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.softmax_scale = inputs[-11] + ctx.causal = inputs[-10] + ctx.window_size = [inputs[-9], inputs[-8]] + ctx.attention_chunk = inputs[-7] + ctx.softcap = inputs[-6] + ctx.sm_margin = inputs[-1] + + +def _backward(ctx, dout, *grads): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_backward( dout, q, k, v, out, softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, dq, dk, dv, + ctx.softmax_scale, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + False, # deterministic + ctx.sm_margin, ) - return dq, dk, dv, softmax_d + return dq, dk, dv, *((None,) * 21) + + +_flash_attn_forward.register_autograd(_backward, setup_context=setup_context) + class FlashAttnQKVPackedFunc(torch.autograd.Function): @@ -167,9 +449,12 @@ def forward( causal, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), + attention_chunk=0, softcap=0.0, deterministic=False, num_heads_q=None, + sm_margin=0, + return_softmax=False, ): if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) @@ -197,23 +482,28 @@ def forward( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], + attention_chunk=attention_chunk, softcap=softcap, + 
sm_margin=sm_margin, ) # ctx.save_for_backward(q, k, v, out_padded, softmax_lse) ctx.save_for_backward(q, k, v, out, softmax_lse) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.attention_chunk = attention_chunk ctx.softcap = softcap ctx.deterministic = deterministic ctx.ndim = qkv.dim() - # return out, softmax_lse - return out + ctx.sm_margin = sm_margin + return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): q, k, v, out, softmax_lse = ctx.saved_tensors + assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" if ctx.ndim == 5: qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) @@ -239,12 +529,14 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, - ctx.window_size, + ctx.window_size[0], + ctx.window_size[1], ctx.softcap, - ctx.deterministic, + ctx.deterministic, + ctx.sm_margin, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None + return dqkv, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnFunc(torch.autograd.Function): @@ -260,15 +552,13 @@ def forward( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), + attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, deterministic=False, sm_margin=0, - s_aux=None, - cp_world_size=1, - cp_rank=0, - cp_tot_seqused_k=None, + return_softmax=False, ): if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) @@ -288,7 +578,9 @@ def forward( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], + attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, @@ -303,14 +595,16 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.attention_chunk = attention_chunk ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin - return out, softmax_lse + return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): q, k, v, out, softmax_lse = ctx.saved_tensors + assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_backward( dout, @@ -327,14 +621,15 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, - ctx.window_size, + ctx.window_size[0], + ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] + dq = dq[..., : q.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None @@ -357,15 +652,13 @@ def forward( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), + attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, deterministic=False, sm_margin=0, - s_aux=None, - cp_world_size=1, - cp_rank=0, - cp_tot_seqused_k=0, + return_softmax=False, ): if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** 
(-0.5) @@ -389,7 +682,9 @@ def forward( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], + attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, @@ -406,14 +701,16 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.attention_chunk = attention_chunk ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin - return out, softmax_lse + return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors + assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_backward( dout, @@ -433,14 +730,15 @@ def backward(ctx, dout, *args): dv, ctx.softmax_scale, ctx.causal, - ctx.window_size, + ctx.window_size[0], + ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] + dq = dq[..., : q.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None @@ -450,9 +748,12 @@ def flash_attn_qkvpacked_func( causal=False, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), + attention_chunk=0, softcap=0.0, deterministic=False, num_heads_q=None, + sm_margin=0, + return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than @@ -494,9 +795,12 @@ def flash_attn_qkvpacked_func( causal, q_descale, k_descale, v_descale, window_size, + attention_chunk, softcap, deterministic, num_heads_q, + sm_margin, + return_attn_probs, ) @@ -509,15 +813,13 @@ def flash_attn_func( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), + attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, deterministic=False, sm_margin=0, - s_aux=None, - cp_world_size=1, - cp_rank=0, - cp_tot_seqused_k=None, + return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads @@ -573,15 +875,13 @@ def flash_attn_func( qv, q_descale, k_descale, v_descale, window_size, + attention_chunk, softcap, num_splits, pack_gqa, deterministic, sm_margin, - s_aux, - cp_world_size, - cp_rank, - cp_tot_seqused_k, + return_attn_probs, ) @@ -600,15 +900,13 @@ def flash_attn_varlen_func( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), + attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, deterministic=False, sm_margin=0, - s_aux=None, - cp_world_size=1, - cp_rank=0, - cp_tot_seqused_k=None, + return_attn_probs=False, ): return FlashAttnVarlenFunc.apply( q, @@ -625,15 +923,13 @@ def flash_attn_varlen_func( qv, q_descale, k_descale, v_descale, window_size, + attention_chunk, softcap, num_splits, pack_gqa, deterministic, sm_margin, - s_aux, - cp_world_size, - cp_rank, - cp_tot_seqused_k, + return_attn_probs, ) @@ -664,6 +960,7 @@ def flash_attn_with_kvcache( softmax_scale=None, 
causal=False, window_size=(-1, -1), # -1 means infinite context window + attention_chunk=0, softcap=0.0, # 0.0 means deactivated rotary_interleaved=True, scheduler_metadata=None, @@ -722,7 +1019,7 @@ def flash_attn_with_kvcache( q: (batch_size, seqlen, nheads, headdim) k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) - page_block_size must be a multiple of 256. + page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.). v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate @@ -767,7 +1064,7 @@ def flash_attn_with_kvcache( softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) cache_seqlens = maybe_contiguous(cache_seqlens) out, softmax_lse, *rest = _flash_attn_forward( @@ -794,7 +1091,9 @@ def flash_attn_with_kvcache( q_descale, k_descale, v_descale, softmax_scale, causal=causal, - window_size=window_size, + window_size_left=window_size[0], + window_size_right=window_size[1], + attention_chunk=attention_chunk, softcap=softcap, rotary_interleaved=rotary_interleaved, scheduler_metadata=scheduler_metadata, @@ -822,6 +1121,7 @@ def get_scheduler_metadata( max_seqlen_k_new=0, causal=False, window_size=(-1, -1), # -1 means infinite context window + attention_chunk=0, has_softcap=False, num_splits=0, # Can be tuned for speed pack_gqa=None, # Can be tuned for speed @@ -843,6 +1143,7 @@ def get_scheduler_metadata( max_seqlen_k_new, causal, window_size[0], window_size[1], + attention_chunk, has_softcap, num_splits, pack_gqa, diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index 76ded0407ec..6df3231cdd4 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -49,7 +49,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { using PreprocessKernel = flash::FlashAttnBwdPreprocess; typename PreprocessKernel::Arguments preprocess_args { static_cast(params.o_ptr), - {seqlen_q, params.d, params.h, batch_q}, // shape_O + {seqlen_q, params.dv, params.h, batch_q}, // shape_O {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O static_cast(params.do_ptr), {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO @@ -93,7 +93,11 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { flash::CollectiveEpilogueBwd= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>, flash::CollectiveEpilogueBwdGQA >; - using Scheduler = flash::SingleTileScheduler; + using Scheduler = std::conditional_t< + Is_causal, + flash::SingleTileBwdLPTScheduler, + flash::SingleTileScheduler + >; using AttnKernel = std::conditional_t< Arch >= 90, flash::enable_sm90_or_later>, @@ -108,8 +112,10 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { {seqlen_k, params.d, params.h_k, batch_k}, // shape_K {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? 
params.k_batch_stride : 0}, // stride_K static_cast(params.v_ptr), + {seqlen_k, params.dv, params.h_k, batch_k}, // shape_V {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V static_cast(params.do_ptr), + {seqlen_q, params.dv, params.h, batch_q}, // shape_dO {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO static_cast(params.dq_accum_ptr), {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum @@ -120,7 +126,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { static_cast(params.dsoftmax_sum), {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum params.scale_softmax, - params.window_size_left, params.window_size_right, + params.window_size_left, params.window_size_right, 0 /*attention_chunk*/, params.softcap, params.b, params.dq_semaphore, @@ -145,13 +151,21 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { } }(), static_cast(!GQA ? params.dv_ptr : params.dv_accum_ptr), + [&] { + if constexpr (!GQA) { + return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k}; // shape_dV + } else { + return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}; // shape_dVaccum + } + }(), [&] { if constexpr (!GQA) { return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV } else { - return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum + return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum } }(), + params.b, params.h, params.dk_semaphore, params.dv_semaphore, @@ -256,10 +270,10 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args); typename PostprocessKerneldKV::Arguments postprocess_dV_args { static_cast(params.dv_accum_ptr), - {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dVaccum - {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum + {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}, // shape_dVaccum + {_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? 
params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum static_cast(params.dv_ptr), - {seqlen_k, params.d, params.h_k, batch_k}, // shape_dV + {seqlen_k, params.dv, params.h_k, batch_k}, // shape_dV {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV 1.f, params.cu_seqlens_k, @@ -288,10 +302,11 @@ template(params, stream); - run_flash_bwd(params, stream); -// }); + BOOL_SWITCH(params.deterministic, Deterministic_, [&] { + static constexpr bool Deterministic = Deterministic_ && kHeadDim < 256; + // run_flash_bwd(params, stream); + run_flash_bwd(params, stream); + }); }); }); } diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index cb3f6dbebcb..ed106de4562 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -139,7 +139,6 @@ class FlashAttnFwdCombine { // Device side arguments struct Arguments { - int b; ElementPartial const* const ptr_O_partial; ShapeOPartial const shape_O_partial; StrideOPartial const stride_O_partial; @@ -159,7 +158,6 @@ class FlashAttnFwdCombine { // Kernel entry point API struct CollectiveParams { - int b; ElementPartial const* const ptr_O_partial; ShapeOPartial const shape_O_partial; StrideOPartial const stride_O_partial; @@ -184,7 +182,6 @@ class FlashAttnFwdCombine { to_underlying_arguments(Arguments const& args) { assert(get<1>(args.shape_LSE_partial) <= kMaxSplits); return { - args.b, args.ptr_O_partial, args.shape_O_partial, args.stride_O_partial, @@ -434,24 +431,24 @@ class FlashAttnFwdCombine { Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{}); int const thread_idx = threadIdx.x; - - BlockCoord block_coord = tile_scheduler.get_block_coord(kernel_params.scheduler_params); - - int const m_block = block_coord.block_m; - int const k_block = block_coord.block_k; - int const maybe_virtual_batch = block_coord.bidb; - if (maybe_virtual_batch >= params.b) { return; } + int const m_block = blockIdx.x; + int const k_block = blockIdx.y; + int const maybe_virtual_batch = blockIdx.z; int const batch = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[maybe_virtual_batch] : maybe_virtual_batch; - + int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[maybe_virtual_batch] : get<1>(params.shape_LSE_partial); + + if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { + cutlass::arch::wait_on_dependent_grids(); + *params.semaphore_to_reset = 0; + } + if (num_splits <= 1) { return; } flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; int const offset = seqlen_info.offset; int const seqlen = seqlen_info.seqlen; int max_idx = seqlen * get<2>(params.shape_LSE_partial); - - if (m_block >= cute::ceil_div(max_idx, Int{})) { return; } - - int const num_splits = params.num_splits_dynamic_ptr ? 
params.num_splits_dynamic_ptr[maybe_virtual_batch] : get<1>(params.shape_LSE_partial); - if (num_splits <= 1) { return; } + if constexpr (Varlen) { + if (m_block * kBlockM >= max_idx) { return; } + } cutlass::FastDivmod seqlen_divmod_dynamic(seqlen); diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index cbaa40f34b3..ce497686045 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -25,7 +25,6 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool e IsEvenK, Varlen, Element, ElementPartial, ArchTag>; typename CombineKernel::Arguments args { - params.b, static_cast(params.oaccum_ptr), {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial @@ -39,12 +38,19 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool e params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.varlen_batch_idx_ptr, params.tile_count_semaphore }; - typename CombineKernel::SchedulerArguments scheduler_args { - params.b, params.seqlen_q, params.total_q, params.h, params.h_k, params.dv, params.pack_gqa, - params.cu_seqlens_q, params.seqused_q, params.prepare_seqlen_q_ptr + typename CombineKernel::SchedulerArguments scheduler_args { + params.b, + params.seqlen_q, + params.total_q, + params.h, + params.h_k, + params.dv, + params.pack_gqa, + params.cu_seqlens_q, + params.seqused_q, + nullptr }; - - typename CombineKernel::Params kernel_params = { + typename CombineKernel::Params kernel_params { CombineKernel::to_underlying_arguments(args), CombineKernel::TileScheduler::to_underlying_arguments(scheduler_args) }; diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index e8db360b144..bd2c787d499 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -27,7 +27,7 @@ using namespace cute; template + bool PackGQA, bool Split, bool V_colmajor> void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time"); @@ -38,7 +38,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ElementS = cutlass::bfloat16_t; // Can't use structured binding since it's not compatible with constexpr - static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg); + static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap); static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV); static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); static constexpr int kBlockN = Arch >= 90 ? 
std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); @@ -51,17 +51,17 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using TileShape_MNK = cute::Shape, Int, Int>; using TileShape_MNK_PV = cute::Shape, Int, Int>; using ClusterShape = cute::Shape, _1, _1>; + using ElementSAux = cutlass::bfloat16_t; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, - flash::CollectiveMainloopFwdSm80 + flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm80 >; - using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; static constexpr bool LPT = Is_causal || Is_local; - // static constexpr bool Sort = !Is_local; - static constexpr bool Sort = false; + static constexpr bool Sort = !Is_local; using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/, LPT, Sort, true /*Prepared*/>, std::conditional_t= 90 && (kHeadDim == 128 || kHeadDim == 64) && (kHeadDimV == kHeadDim) && !Is_local; - // Only needed here to decide if we should use cluster - static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg)) : 128; - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen && !Use_one_mma_wg; - QV_SWITCH(params.qv_ptr, HasQV_, [&] { - static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; - APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not varlen - CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { - static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - int const qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k); - PACK_GQA_BLOCK_SWITCH(qhead_per_khead, kBlockH_, [&] { - // Non-unary values of kBlockH can improve GQA perf for specific ratios (4, 8, 16) by enabling TMA for loading Q - // Disable for hdim diff, fp16, 1 mma wg or split to shrink build - static constexpr int kBlockH = !PackGQA || Arch < 90 || (kHeadDim != kHeadDimV) || cute::is_same_v || Use_one_mma_wg || Split ? 1 : kBlockH_; - run_flash_fwd(params, stream); - }); - }); + // Only needed here to decide if we should use cluster + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; + static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; + BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; + APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not varlen + CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { + static constexpr int ClusterM = Enable_cluster && Use_cluster ? 
2 : 1; + run_flash_fwd(params, stream); }); }); }); diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index bc0a51cf0ab..3ed98b3b3e0 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -47,7 +47,7 @@ __global__ void prepare_varlen_num_blocks_kernel( int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, int* const tile_count_semaphore, - int* const prepare_seqlen_q_ptr, + int* const num_m_blocks_ptr, int* const num_splits_dynamic_ptr, int* const varlen_batch_idx_ptr, // int* const num_n_blocks_ptr, @@ -78,7 +78,7 @@ __global__ void prepare_varlen_num_blocks_kernel( int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - auto get_num_m_blocks_and_seqlen = [&](int batch_idx) { + auto get_num_m_blocks = [&](int batch_idx) { int seqlen; if (seqused_q) { seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0; @@ -89,9 +89,9 @@ __global__ void prepare_varlen_num_blocks_kernel( } else { seqlen = seqlen_q_static; } + if(packgqa) { seqlen *= qhead_per_khead; } return batch_idx < num_batch && lane < kNumBatchPerWarp - ? cute::make_tuple(blockm_divmod.div(seqlen * (packgqa ? qhead_per_khead : 1) + blockm_divmod.divisor - 1), seqlen) - : cute::make_tuple(0, 0); + ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; }; auto get_num_n_blocks = [&](int batch_idx) { @@ -124,10 +124,7 @@ __global__ void prepare_varlen_num_blocks_kernel( int batch_cta_idx_offset = int(blockIdx.x) * 992; int bidb_start = batch_cta_idx_offset + kNumBatchPerWarp * warp_idx; int batch_idx = lane + bidb_start; - // int num_m_blocks = get_num_m_blocks(batch_idx); - auto seqlen_q_info = get_num_m_blocks_and_seqlen(batch_idx); - int num_m_blocks = cute::get<0>(seqlen_q_info); - int seqlen_q = cute::get<1>(seqlen_q_info); + int num_m_blocks = get_num_m_blocks(batch_idx); int num_n_blocks = get_num_n_blocks(batch_idx); auto get_nheads_in_l2 = [&](int n_blocks) { @@ -166,23 +163,24 @@ __global__ void prepare_varlen_num_blocks_kernel( if constexpr (Sort) { if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { num_n_blocks = INT_MIN; // sort last - } - else if (is_causal) { - // sort by middle member to process - num_n_blocks = num_n_blocks * blockn_divmod.divisor - (seqlen_q / 2); + } else if (is_causal) { // sort by shortest member to process - // num_n_blocks = num_n_blocks * blockn_divmod.divisor - seqlen_q; + num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor; } int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread - batch_coords[0] = make_int4(num_n_blocks, seqlen_q, num_splits_dynamic, batch_idx); + batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); + + // if (threadIdx.x == 0) { + // printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); // Sort batches by num_n_blocks in descending order BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp()); if (is_causal) { // reset value to num_n_blocks - batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + (batch_coords[0].y / 2)); - // batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y); + batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor); } // When sorting, we re-index some metadata 
by 'virtual batch index' @@ -194,15 +192,16 @@ __global__ void prepare_varlen_num_blocks_kernel( batch_idx = batch_cta_idx_offset + threadIdx.x; if (batch_idx < num_batch && threadIdx.x < 992) { if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(batch_coords[0].x, 1)); } - prepare_seqlen_q_ptr[batch_idx] = batch_coords[0].y * (packgqa ? qhead_per_khead : 1); - if(num_splits_dynamic_ptr) { num_splits_dynamic_ptr[batch_idx] = batch_coords[0].z; } + num_m_blocks_ptr[batch_idx] = batch_coords[0].y; + num_splits_dynamic_ptr[batch_idx] = batch_coords[0].z; varlen_batch_idx_ptr[batch_idx] = batch_coords[0].w; } } else { if (batch_idx < num_batch && lane < kNumBatchPerWarp) { - prepare_seqlen_q_ptr[batch_idx] = seqlen_q * (packgqa ? qhead_per_khead : 1); - if(num_splits_dynamic_ptr) { num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; } + // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1); if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } + num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; + num_m_blocks_ptr[batch_idx] = num_m_blocks; // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); } } @@ -217,12 +216,7 @@ void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bo int num_warps = cutlass::ceil_div(params.b, 31); // warp switch will cap this at 32 int num_ctas = cutlass::ceil_div(params.b, 31 * 32); // int const size_l2 = 50 * 1024 * 1024; // 50 MB - int const size_l2_divisor = qhead_per_khead == 1 ? 1 - : qhead_per_khead <= 2 ? 2 - : qhead_per_khead <= 4 ? 4 - : qhead_per_khead <= 8 ? 8 - : 16; - int const size_l2 = (32 * 1024 * 1024) / size_l2_divisor; // experimental + int const size_l2 = 8 * 1024 * 1024; // underestimate seems better in practice int const element_size = params.is_e4m3 ? 1 : 2; int const size_one_kvblock = blockN * (params.d + params.dv) * element_size; // printf("block size = %d, element size = %d, headdim = %d, headdim_v = %d, size 1 kblock = %d.\n", blockN, element_size, params.d, params.dv, size_one_kvblock); @@ -236,7 +230,7 @@ void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bo params.b, !packgqa ? 
params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), params.tile_count_semaphore, - params.prepare_seqlen_q_ptr, + params.num_m_blocks_ptr, params.num_splits_dynamic_ptr, params.varlen_batch_idx_ptr, // params.num_n_blocks_ptr, diff --git a/hopper/mainloop_bwd_sm80.hpp b/hopper/mainloop_bwd_sm80.hpp index 0a79670f475..23baae61731 100644 --- a/hopper/mainloop_bwd_sm80.hpp +++ b/hopper/mainloop_bwd_sm80.hpp @@ -284,8 +284,10 @@ struct CollectiveMainloopBwdSm80 { ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const ptr_V; + ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; + ShapeQKV const shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; @@ -296,7 +298,7 @@ struct CollectiveMainloopBwdSm80 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; - int const window_size_left, window_size_right; + int const window_size_left, window_size_right, attention_chunk; float const softcap_val; int const num_batch; int* const dq_semaphore; @@ -315,8 +317,10 @@ struct CollectiveMainloopBwdSm80 { ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const ptr_V; + ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; + ShapeQKV const shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; @@ -329,6 +333,7 @@ struct CollectiveMainloopBwdSm80 { StrideLSE const stride_dPsum; float const softmax_scale, softmax_scale_log2; int const window_size_left, window_size_right; + cutlass::FastDivmod attention_chunk_divmod; float const softcap_val; int const num_batch; int *const dq_semaphore; @@ -341,6 +346,9 @@ struct CollectiveMainloopBwdSm80 { static Params to_underlying_arguments(Arguments const& args) { if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } + // Avoid dividing by zero + cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); + attention_chunk_divmod.divisor = args.attention_chunk; // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val @@ -352,14 +360,14 @@ struct CollectiveMainloopBwdSm80 { // (the original softmax_scale) at the end. return {args.ptr_Q, args.shape_Q, args.stride_Q, args.ptr_K, args.shape_K, args.stride_K, - args.ptr_V, args.stride_V, - args.ptr_dO, args.stride_dO, + args.ptr_V, args.shape_V, args.stride_V, + args.ptr_dO, args.shape_dO, args.stride_dO, args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), - args.window_size_left, args.window_size_right, + args.window_size_left, args.window_size_right, attention_chunk_divmod, !Has_softcap ? 
0.f : args.softmax_scale / args.softcap_val, args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; @@ -413,9 +421,9 @@ struct CollectiveMainloopBwdSm80 { bool const is_varlen_k = Varlen && params.cu_seqlens_k; int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q), params.shape_Q, params.stride_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_Q, params.stride_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_dO, params.stride_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); - Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0); Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), @@ -527,13 +535,16 @@ struct CollectiveMainloopBwdSm80 { for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); } Tensor cLSE = cute::make_identity_tensor(select<0>(TileShape_MNK{})); Tensor tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE); + Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOsdO))); + #pragma unroll + for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_dO); } int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; flash::Mask mask( thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, - params.qhead_per_khead_divmod + params.attention_chunk_divmod, params.qhead_per_khead_divmod ); { @@ -545,9 +556,12 @@ struct CollectiveMainloopBwdSm80 { Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + Tensor tVpV = make_tensor(make_shape(size<2>(tVsV))); + #pragma unroll + for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } + for (int k = 0; k < size(tVpV); ++k) { tVpV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_V); } // Do we need bound check to make sure the row doesn't go above kBlockN static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; // static_assert(EvenN); // It simplifies the loading of K and V @@ -567,7 +581,7 @@ struct CollectiveMainloopBwdSm80 { bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; #pragma unroll for (int k = 0; k < size<2>(tVsV); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV(_, m, k), tVsV(_, m, k)); + 
cute::copy(gmem_tiled_copy_QKV.with(tVpV(k) && predicate_n), tVgV(_, m, k), tVsV(_, m, k)); } } } @@ -580,7 +594,7 @@ struct CollectiveMainloopBwdSm80 { bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; #pragma unroll for (int k = 0; k < size<2>(tKsK); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tKgK(_, m, k), tKsK(_, m, k)); + cute::copy(gmem_tiled_copy_QKV.with(tKpK(k) && predicate_n), tKgK(_, m, k), tKsK(_, m, k)); } } } @@ -653,7 +667,7 @@ struct CollectiveMainloopBwdSm80 { bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit; #pragma unroll for (int k = 0; k < size<2>(tdOsdO); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tQpQ(k) && predicate_m), tdOgdO_cur(_, m, k), tdOsdO_cur(_, m, k)); + cute::copy(gmem_tiled_copy_QKV.with(tdOpdO(k) && predicate_m), tdOgdO_cur(_, m, k), tdOsdO_cur(_, m, k)); } } } @@ -817,21 +831,21 @@ struct CollectiveMainloopBwdSm80 { // if (cute::thread0()) { print_tensor(tdVrdV); } __syncthreads(); // make sure sdS is written auto do_mma_dQ = [&] (auto hook) { - Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); - clear(tdQrdQ); - Tensor tdQrdS = mma_partition_fragment_AB(thr_mma_dQ, sdS); - Tensor tdQrK = mma_partition_fragment_AB(thr_mma_dQ, sKt); - flash::gemm_sm80( - tdQrdQ, tdQrdS, tdQrK, tdQsdS, tdQsKt, tiled_mma_dQ, - // smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, load_dO_next); - smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, hook); - // if (cute::thread0()) { print_tensor(tdQrdQ); } - // We can reuse r2s_thr_copy_dQaccum for this partitioning - Tensor tdQrdQ_atomic = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); - Tensor tdQgdQaccum_atomic = tdQgdQaccum(_, _, m_block); - static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic))); - #pragma unroll - for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } + Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); + clear(tdQrdQ); + Tensor tdQrdS = mma_partition_fragment_AB(thr_mma_dQ, sdS); + Tensor tdQrK = mma_partition_fragment_AB(thr_mma_dQ, sKt); + flash::gemm_sm80( + tdQrdQ, tdQrdS, tdQrK, tdQsdS, tdQsKt, tiled_mma_dQ, + // smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, load_dO_next); + smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, hook); + // if (cute::thread0()) { print_tensor(tdQrdQ); } + // We can reuse r2s_thr_copy_dQaccum for this partitioning + Tensor tdQrdQ_atomic = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); + Tensor tdQgdQaccum_atomic = tdQgdQaccum(_, _, m_block); + static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic))); + #pragma unroll + for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } }; // If kStages == 1, we want to do Mma_dK first so we can start loading Q for the next iteration if constexpr (kStages > 1) { do_mma_dQ(load_dO_next); } diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index 71cfb020469..c67ae17969f 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -298,8 +298,10 @@ struct CollectiveMainloopBwdSm90 { ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const ptr_V; + ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; + ShapeQKV const 
shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; @@ -310,7 +312,7 @@ struct CollectiveMainloopBwdSm90 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; - int const window_size_left, window_size_right; + int const window_size_left, window_size_right, attention_chunk; float const softcap_val; int const num_batch; int* const dq_semaphore; @@ -324,6 +326,8 @@ struct CollectiveMainloopBwdSm90 { struct Params { ShapeQKV const shape_Q; ShapeQKV const shape_K; + ShapeQKV const shape_V; + ShapeQKV const shape_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; StridedQaccum stride_dQaccum; @@ -338,6 +342,7 @@ struct CollectiveMainloopBwdSm90 { StrideLSE const stride_dPsum; float const softmax_scale, softmax_scale_log2; int const window_size_left, window_size_right; + cutlass::FastDivmod attention_chunk_divmod; float const softcap_val; int const num_batch; int* const dq_semaphore; @@ -356,7 +361,7 @@ struct CollectiveMainloopBwdSm90 { SmemLayoutQ{}(_, _, _0{}), TileShape_MNK{}, ClusterShape{}); // mcast along N mode for this M load, if any - Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO); + Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_dO, args.stride_dO); TMA_QdO tma_load_dO = make_tma_copy_A_sm90( GmemTiledCopyQdO{}, mdO, @@ -370,7 +375,7 @@ struct CollectiveMainloopBwdSm90 { SmemLayoutK{}, TileShape_MNK{}, ClusterShape{}); // no mcast for KV - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V); + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_V, args.stride_V); TMA_V tma_load_V = make_tma_copy_B_sm90( GmemTiledCopyKV{}, mV, @@ -378,6 +383,9 @@ struct CollectiveMainloopBwdSm90 { TileShape_MNK{}, ClusterShape{}); // no mcast for KV if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } + // Avoid dividing by zero + cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); + attention_chunk_divmod.divisor = args.attention_chunk; // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val @@ -387,14 +395,15 @@ struct CollectiveMainloopBwdSm90 { // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale. // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale // (the original softmax_scale) at the end. - return {args.shape_Q, args.shape_K, + return {args.shape_Q, args.shape_K, + args.shape_V, args.shape_dO, args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), - args.window_size_left, args.window_size_right, + args.window_size_left, args.window_size_right, attention_chunk_divmod, !Has_softcap ? 
0.f : args.softmax_scale / args.softcap_val, args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; @@ -453,9 +462,9 @@ struct CollectiveMainloopBwdSm90 { bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); - Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0); @@ -598,7 +607,8 @@ struct CollectiveMainloopBwdSm90 { seqlen_info, n_block, bidb, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. Exit early - if constexpr (Is_causal || Is_local || Varlen) { + // Though if local and deterministic, still need to increment dq semaphore + if constexpr ((Is_causal || Is_local || Varlen) && !(Is_local && Deterministic)) { if (m_block_max <= m_block_min) { return; } } @@ -617,10 +627,18 @@ struct CollectiveMainloopBwdSm90 { using Barrier = cutlass::GenericBarrier; bool const lane_predicate = cute::elect_one_sync(); int m_block = m_block_min; + constexpr int kBlockM = get<0>(TileShape_MNK{}); + constexpr int kBlockN = get<1>(TileShape_MNK{}); + int n_block_global_max = cute::ceil_div(seqlen_info.seqlen_k, kBlockN); #pragma unroll 2 for (; m_block < m_block_max; ++m_block) { if constexpr (Deterministic) { - Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); + if constexpr(Is_causal) { + int n_block_max_for_m_block = std::min(n_block_global_max, cute::ceil_div((m_block + 1) * kBlockM + seqlen_info.seqlen_k - seqlen_info.seqlen_q, kBlockN)); + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block_max_for_m_block - 1 - n_block); + } else { + Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); + } } #pragma unroll for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) { @@ -640,7 +658,6 @@ struct CollectiveMainloopBwdSm90 { } } if constexpr (Is_local && Deterministic) { - constexpr int kBlockM = get<0>(TileShape_MNK{}); int const m_block_global_max = cute::ceil_div(seqlen_info.seqlen_q, kBlockM); #pragma unroll 2 for (; m_block < m_block_global_max; ++m_block) { @@ -793,7 +810,7 @@ struct CollectiveMainloopBwdSm90 { flash::Mask mask( thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, - params.qhead_per_khead_divmod + params.attention_chunk_divmod, params.qhead_per_khead_divmod ); int m_block = m_block_min; @@ -921,7 +938,7 @@ struct CollectiveMainloopBwdSm90 { Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); Tensor 
tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); - pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dQ + pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO if constexpr (Mma_dKV_is_RS) { Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index a7c42fe0ffb..cf30b4036fc 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -203,7 +203,7 @@ struct CollectiveMainloopFwdSm80 { float const softmax_scale; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; - int const window_size_left = -1, window_size_right = -1; + int const window_size_left = -1, window_size_right = -1, attention_chunk = 0; float const softcap_val; int const num_splits; int const* const kv_batch_idx = nullptr; @@ -254,6 +254,7 @@ struct CollectiveMainloopFwdSm80 { StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; float const softcap_val; int const window_size_left, window_size_right; + cutlass::FastDivmod attention_chunk_divmod; int const num_splits; int const* const kv_batch_idx = nullptr; int const* const cu_seqlens_q = nullptr; @@ -282,6 +283,9 @@ struct CollectiveMainloopFwdSm80 { assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); } assert(args.num_splits >= 1); + // Avoid dividing by zero + cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); + attention_chunk_divmod.divisor = args.attention_chunk; // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val @@ -299,7 +303,7 @@ struct CollectiveMainloopFwdSm80 { args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, - args.window_size_left, args.window_size_right, + args.window_size_left, args.window_size_right, attention_chunk_divmod, !Split ? 1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, @@ -329,7 +333,8 @@ struct CollectiveMainloopFwdSm80 { int const bidh_kv = !PackGQA ? 
params.qhead_per_khead_divmod.divide(bidh) : bidh; auto n_block_min_max = BlockMN_t::get_n_block_min_max( seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); int const n_block_min = get<0>(n_block_min_max); int const n_block_max = get<1>(n_block_min_max); int const n_offset = get<2>(n_block_min_max); @@ -555,7 +560,7 @@ struct CollectiveMainloopFwdSm80 { flash::Mask mask( thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, - params.qhead_per_khead_divmod + params.attention_chunk_divmod, params.qhead_per_khead_divmod ); float softcap_val = params.softcap_val; @@ -628,19 +633,17 @@ struct CollectiveMainloopFwdSm80 { --n_block; if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); - int const n_block_min_causal_local_mask = - std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); + int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_right, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); #pragma unroll 1 for (; n_block >= n_block_min_causal_local_mask; --n_block) { fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); } } - int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; - int const n_block_min_before_local_mask = !Is_local - ? 
n_block_min - : std::max(n_block_min, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); + int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_left, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); auto no_mask_fn = [](auto& tSrS, int n_block) { }; #pragma unroll 1 for (; n_block >= n_block_min_before_local_mask; --n_block) { @@ -672,7 +675,8 @@ struct CollectiveMainloopFwdSm80 { auto [m_block, bidh, bidb, split_idx] = block_coord; auto n_block_new_min_max = BlockMN_t::get_n_block_k_new_min_max( seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); int const n_block_new_min = get<0>(n_block_new_min_max); int const n_block_new_max = get<1>(n_block_new_min_max); if (n_block_new_max <= n_block_new_min) { return false; } diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 6e0d8b768b7..9147cf7c4a6 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -55,7 +55,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Split = Split_; static constexpr bool V_colmajor = V_colmajor_; static constexpr bool Transpose_V = Is_FP8 && !V_colmajor; - static constexpr bool Use_TMA_Q = !PackGQA || PackGQA_TMA; + static constexpr bool Use_TMA_Q = !PackGQA; static constexpr bool Use_TMA_KV = !PagedKVNonTMA; static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); @@ -400,7 +400,7 @@ struct CollectiveMainloopFwdSm90 { float const softmax_scale; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; - int const window_size_left = -1, window_size_right = -1; + int const window_size_left = -1, window_size_right = -1, attention_chunk = 0; float const softcap_val; int const num_splits; int const* const kv_batch_idx = nullptr; @@ -424,6 +424,7 @@ struct CollectiveMainloopFwdSm90 { ShapeQKV const shape_Q; StrideQK const stride_Q; ShapeQPacked const shape_Q_packed; + ShapeQPackedTMA const shape_Q_packed_tma; StrideQPacked const stride_Q_packed; Element* const ptr_K; ShapeQKV const shape_K; @@ -463,6 +464,7 @@ struct CollectiveMainloopFwdSm90 { StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; float const softcap_val; int const window_size_left, window_size_right; + cutlass::FastDivmod attention_chunk_divmod; int const num_splits; int const* const kv_batch_idx = nullptr; int const* const cu_seqlens_q = nullptr; @@ -473,8 +475,8 @@ struct CollectiveMainloopFwdSm90 { int const* const leftpad_k = nullptr; int const* const seqlens_rotary = nullptr; ElementSAux const* const ptr_S_aux = nullptr; - int cp_world_size = 1; - int cp_rank = 0; + int const cp_world_size = 1; + int const cp_rank = 0; int const* const cp_tot_seqused_k = nullptr; }; @@ -564,14 +566,16 @@ struct CollectiveMainloopFwdSm90 { if (!PagedKVNonTMA && args.ptr_pagetable != nullptr) { assert(page_size % kBlockN == 0); assert(!args.leftpad_k); - assert(!Is_local); // Since we now use leftpad_k with local, we can't use TMA with PagedKV } + // Avoid 
dividing by zero + cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); + attention_chunk_divmod.divisor = args.attention_chunk; // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). - return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, + return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, shape_Q_packed_tma, stride_Q_packed, args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, args.ptr_Qv, args.stride_Qv, shape_Qv_packed, stride_Qv_packed, @@ -586,7 +590,7 @@ struct CollectiveMainloopFwdSm90 { args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, - args.window_size_left, args.window_size_right, + args.window_size_left, args.window_size_right, attention_chunk_divmod, !Split ? 1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, @@ -639,7 +643,9 @@ struct CollectiveMainloopFwdSm90 { // seqlen_k -> seqlen_k - n_offset auto [n_block_min, n_block_max, n_offset] = BlockMN_t::get_n_block_min_max( seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); + (void)n_offset; // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { @@ -683,20 +689,12 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; - Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q_packed_tma)(_, _, bidh, !is_varlen_q ? 
bidb : 0); Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, _); auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, _); - - Tensor gQ = local_tile( - domain_offset( - cute::conditional_return( - make_coord(seqlen_info.offset_q, _0{}), - make_coord(make_coord(_0{}, seqlen_info.offset_q), _0{})), - mQ), - select<0, 2>(TileShape_MNK{}), - make_coord(m_block, _0{})); // (M, K) // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } + Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) // Now add n_offset to update KV gmem pointers Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k + n_offset, _0{}, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}, _)); // (N, K, _, _) Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k + n_offset, _0{}), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _, _)); // (K, N, _, _) @@ -704,6 +702,7 @@ struct CollectiveMainloopFwdSm90 { auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) + if (Use_TMA_Q && thread_idx == 0) { prefetch(params.tma_load_Q, tQgQ); } // tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually auto block_tma_K = params.tma_load_K.get_slice(cluster_local_block_id.x); Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA)); // (TMA, k, batch) @@ -1014,7 +1013,8 @@ struct CollectiveMainloopFwdSm90 { int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; auto [n_block_min, n_block_max, n_offset] = BlockMN_t::get_n_block_min_max( seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { return false; } @@ -1101,7 +1101,7 @@ struct CollectiveMainloopFwdSm90 { // But we subtract n_offset for consistency in mask calculations flash::Mask mask( thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 - n_offset /*sink_token_length*/, - params.qhead_per_khead_divmod, + params.attention_chunk_divmod, params.qhead_per_khead_divmod, params.cp_world_size, params.cp_rank, seqlen_info.tot_seqlen_k ); @@ -1282,32 +1282,18 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); - // If local, blocking (window_size_right + window_size_left) - // when cp is not enabled, tot_seqlen_k is equal to seqlen_k, and cp_world_size is 1. 
- // cp_world_size is guaranteed to be greater than 0 - int const n_block_min_causal_local_mask = - std::max(n_block_min, - (m_idx_min + seqlen_info.tot_seqlen_k - seqlen_q + params.window_size_right) / - seqlen_info.cp_world_size / - kBlockN); + int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_right, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); #pragma unroll 1 for (; n_block >= n_block_min_causal_local_mask; --n_block) { fwd_step(n_block, mask_fn, cute::true_type{} /*check_inf*/); } } - int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; - // If local, blocking (m_idx_max - m_idx_min) - // when cp is not enabled, tot_seqlen_k is equal to seqlen_k, and cp_world_size is 1. - // cp_world_size is guaranteed to be greater than 0 - int const n_block_min_before_local_mask = !Is_local - ? n_block_min - : std::max(n_block_min, - cute::ceil_div( - cute::ceil_div(m_idx_max + seqlen_info.tot_seqlen_k - seqlen_q - params.window_size_left - seqlen_info.cp_rank, - seqlen_info.cp_world_size), - kBlockN)); + int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_left, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); auto no_mask_fn = [](auto& tSrS, int n_block) { }; #pragma unroll 1 for (; n_block >= n_block_min_before_local_mask; --n_block) { @@ -1398,21 +1384,17 @@ struct CollectiveMainloopFwdSm90 { --n_block; if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); - // If local, blocking (window_size_right + window_size_left) - int const n_block_min_causal_local_mask = - std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); + int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_right, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); #pragma unroll 1 for (; n_block >= n_block_min_causal_local_mask; --n_block) { fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); } } - int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; - // If local, blocking (m_idx_max - m_idx_min) - int const n_block_min_before_local_mask = !Is_local - ? 
n_block_min - : std::max(n_block_min, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); + int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_left, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); auto no_mask_fn = [](auto& tSrS, int n_block) { }; #pragma unroll 1 for (; n_block >= n_block_min_before_local_mask; --n_block) { @@ -1466,7 +1448,8 @@ struct CollectiveMainloopFwdSm90 { int const split_idx = get<3>(block_coord); auto [n_block_min, n_block_max, n_offset] = BlockMN_t::get_n_block_min_max( seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { return false; } @@ -1552,7 +1535,8 @@ struct CollectiveMainloopFwdSm90 { auto [m_block, bidh, bidb, split_idx] = block_coord; auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); if (n_block_new_max <= n_block_new_min) { return false; } @@ -1654,7 +1638,8 @@ struct CollectiveMainloopFwdSm90 { auto [m_block, bidh, bidb, split_idx] = block_coord; auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); if (n_block_new_max <= n_block_new_min) { return false; } // as_position_independent_swizzle_tensor makes address calculation easier diff --git a/hopper/mask.h b/hopper/mask.h index c3cba012505..ddc9d6908a7 100644 --- a/hopper/mask.h +++ b/hopper/mask.h @@ -22,12 +22,14 @@ struct Mask { int const thread_idx; int const seqlen_q, seqlen_k; int const window_size_left, window_size_right, sink_token_length; + cutlass::FastDivmod const attention_chunk_divmod; cutlass::FastDivmod const qhead_per_khead_divmod; int const cp_world_size, cp_rank, tot_seqlen_k; CUTLASS_DEVICE Mask(const int thread_idx, const int seqlen_q, const int seqlen_k, const int window_size_left, const int window_size_right, const int sink_token_length, + cutlass::FastDivmod const &attention_chunk_divmod, cutlass::FastDivmod const &qhead_per_khead_divmod, const int cp_world_size = 1, const int cp_rank = 0, const int tot_seqlen_k = 0) : thread_idx(thread_idx) @@ -36,6 +38,7 @@ struct Mask { , window_size_left(window_size_left) , window_size_right(window_size_right) , sink_token_length(sink_token_length) + , attention_chunk_divmod(attention_chunk_divmod) , qhead_per_khead_divmod(qhead_per_khead_divmod) , cp_world_size(cp_world_size) , cp_rank(cp_rank) @@ -117,16 +120,21 @@ struct Mask { } else { int const local_row_offset_right = causal_row_offset + window_size_right; int const local_row_offset_left = causal_row_offset - 1 - window_size_left; - int const col_limit_sink = sink_token_length - 
n_block * kBlockN; + int const col_limit_sink = sink_token_length - n_block * kBlockN; // TODO: subtract thread_col_offset? #pragma unroll for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { int const row_idx = !PackGQA ? get(tScS_rowcol(m, _0{})) + m_block * kBlockM : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow); - int const col_limit_right = !Seqlenk_mask + int col_limit_right = !Seqlenk_mask ? row_idx + local_row_offset_right : __viaddmin_s32(row_idx, local_row_offset_right, seqlenk_col_limit); - int const col_limit_left = row_idx + local_row_offset_left; + int col_limit_left = row_idx + local_row_offset_left; + if (attention_chunk_divmod.divisor > 0) { + int col_limit_left_chunk = flash::round_down(attention_chunk_divmod, row_idx + seqlen_k - seqlen_q) - n_block * kBlockN - thread_col_offset; + col_limit_left = std::max(col_limit_left, col_limit_left_chunk); + col_limit_right = std::min(col_limit_right, col_limit_left_chunk + attention_chunk_divmod.divisor); + } #pragma unroll for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { int const col_idx = int(get(t0ScS_rowcol(m, n))); diff --git a/hopper/setup.py b/hopper/setup.py index 887b6339023..38748d30804 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -68,31 +68,6 @@ DISABLE_HDIMDIFF64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF64", "FALSE") == "TRUE" DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE" -PACKGQA_ONLY = os.getenv("FLASH_ATTENTION_PACKGQA_ONLY", "FALSE") == "TRUE" - -DISABLE_BACKWARD = True -# DISABLE_SPLIT = True -# DISABLE_PAGEDKV = True -# DISABLE_APPENDKV = True -# DISABLE_LOCAL = True -# DISABLE_SOFTCAP = True -# DISABLE_PACKGQA = True -# DISABLE_FP16 = True -# DISABLE_FP8 = True -# DISABLE_VARLEN = True -DISABLE_CLUSTER = True -# DISABLE_HDIM64 = True -# DISABLE_HDIM96 = True -# DISABLE_HDIM128 = True -# DISABLE_HDIM192 = True -# DISABLE_HDIM256 = True -DISABLE_SM8x = True - -DISABLE_HDIMDIFF64 = True -# DISABLE_HDIMDIFF192 = True - -PACKGQA_ONLY = True - # HACK: we monkey patch pytorch's _write_ninja_file to pass # "-gencode arch=compute_sm90a,code=sm_90a" to files ending in '_sm90.cu', # and pass "-gencode arch=compute_sm80,code=sm_80" to files ending in '_sm80.cu' @@ -108,6 +83,42 @@ _maybe_write, ) +def create_build_config_file(): + CONFIG = { + "build_flags": { + "FLASHATTENTION_DISABLE_BACKWARD": DISABLE_BACKWARD, + "FLASHATTENTION_DISABLE_SPLIT": DISABLE_SPLIT, + "FLASHATTENTION_DISABLE_PAGEDKV": DISABLE_PAGEDKV, + "FLASHATTENTION_DISABLE_APPENDKV": DISABLE_APPENDKV, + "FLASHATTENTION_DISABLE_LOCAL": DISABLE_LOCAL, + "FLASHATTENTION_DISABLE_SOFTCAP": DISABLE_SOFTCAP, + "FLASHATTENTION_DISABLE_PACKGQA": DISABLE_PACKGQA, + "FLASHATTENTION_DISABLE_FP16": DISABLE_FP16, + "FLASHATTENTION_DISABLE_FP8": DISABLE_FP8, + "FLASHATTENTION_DISABLE_VARLEN": DISABLE_VARLEN, + "FLASHATTENTION_DISABLE_CLUSTER": DISABLE_CLUSTER, + "FLASHATTENTION_DISABLE_HDIM64": DISABLE_HDIM64, + "FLASHATTENTION_DISABLE_HDIM96": DISABLE_HDIM96, + "FLASHATTENTION_DISABLE_HDIM128": DISABLE_HDIM128, + "FLASHATTENTION_DISABLE_HDIM192": DISABLE_HDIM192, + "FLASHATTENTION_DISABLE_HDIM256": DISABLE_HDIM256, + "FLASHATTENTION_DISABLE_SM8x": DISABLE_SM8x, + "FLASHATTENTION_ENABLE_VCOLMAJOR": ENABLE_VCOLMAJOR, + "FLASH_ATTENTION_DISABLE_HDIMDIFF64": DISABLE_HDIMDIFF64, + "FLASH_ATTENTION_DISABLE_HDIMDIFF192": DISABLE_HDIMDIFF192, + } + } + + with open("flash_attn_config.py", "w") as f: + f.write("# Auto-generated by flash attention 3 setup.py\n") + f.write(f"CONFIG = 
{repr(CONFIG)}\n") + f.write("\n") + + f.write("def show():\n") + f.write(" from pprint import pprint\n") + f.write(" pprint(CONFIG)\n") + f.write("\n") + def _write_ninja_file(path, cflags, post_cflags, @@ -420,15 +431,23 @@ def nvcc_threads_args(): TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) + create_build_config_file() check_if_cuda_home_none(PACKAGE_NAME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("12.3"): raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above") + elif bare_metal_version >= Version("13.0"): + # CUDA 13.0+ uses system nvcc and CCCL headers are in /usr/local/cuda/include/cccl/ + cccl_include = os.path.join(CUDA_HOME, "include", "cccl") + for env_var in ["CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"]: + current = os.environ.get(env_var, "") + os.environ[env_var] = cccl_include + (":" + current if current else "") # ptxas 12.8 gives the best perf currently # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8 # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have. - if bare_metal_version != Version("12.8"): + # For CUDA 13.0+, use system nvcc instead of downloading CUDA 12.x toolchain + if bare_metal_version >= Version("12.3") and bare_metal_version < Version("13.0") and bare_metal_version != Version("12.8"): download_and_copy( name="nvcc", src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin", @@ -454,7 +473,7 @@ def nvcc_threads_args(): f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", ) base_dir = os.path.dirname(__file__) - ctk_path_new = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", "bin") + ctk_path_new = os.path.abspath(os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", "bin")) nvcc_path_new = os.path.join(ctk_path_new, f"nvcc{exe_extension}") # Need to append to path otherwise nvcc can't find cicc in nvvm/bin/cicc # nvcc 12.8 seems to hard-code looking for cicc in ../nvvm/bin/cicc @@ -498,8 +517,6 @@ def nvcc_threads_args(): + (["-DFLASHATTENTION_ENABLE_VCOLMAJOR"] if ENABLE_VCOLMAJOR else []) + (["-DFLASHATTENTION_DISABLE_HDIMDIFF64"] if DISABLE_HDIMDIFF64 else []) + (["-DFLASHATTENTION_DISABLE_HDIMDIFF192"] if DISABLE_HDIMDIFF192 else []) - + (["-DFLASHATTENTION_PACKGQA_ONLY"] if PACKGQA_ONLY else []) - + (["-DFLASHATTENTION_VARLEN_ONLY"]) ) DTYPE_FWD_SM80 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) @@ -515,18 +532,7 @@ def nvcc_threads_args(): + ([256] if not DISABLE_HDIM256 else []) ) # HEAD_DIMENSIONS_FWD = ["all", "diff"] - # HEAD_DIMENSIONS_FWD = ( - # ["all"] - # + (["diff"] if not DISABLE_HDIMDIFF else []) - # ) - HEAD_DIMENSIONS_FWD = ( - [] - + ([64] if not DISABLE_HDIM64 else []) - + ([96] if not DISABLE_HDIM96 else []) - + ([128] if not DISABLE_HDIM128 else []) - + ([192] if not DISABLE_HDIM192 else []) - + ([256] if not DISABLE_HDIM256 else []) - ) + HEAD_DIMENSIONS_FWD = HEAD_DIMENSIONS_BWD HEAD_DIMENSIONS_DIFF64_FWD = ( [] + (["64_256"] if not DISABLE_HDIMDIFF64 else []) @@ -550,10 +556,7 @@ def nvcc_threads_args(): # We then add "(packgqa or paged or split)" to enable PackGQA always sources_fwd_sm90 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD_SM90, 
SPLIT, PAGEDKV, SOFTCAP, PACKGQA) - if not (packgqa and (paged or split)) and (not PACKGQA_ONLY or packgqa or paged or split)] - # sources_fwd_sm90 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" - # for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) - # if not (packgqa and (paged or split))] + if not (packgqa and (paged or split))] if not DISABLE_HDIMDIFF64: sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF64_FWD, HALF_DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) @@ -569,8 +572,20 @@ def nvcc_threads_args(): if DISABLE_BACKWARD: sources_bwd_sm90 = [] sources_bwd_sm80 = [] + + # Choose between flash_api.cpp and flash_api_stable.cpp based on torch version + torch_version = parse(torch.__version__) + target_version = parse("2.9.0.dev20250830") + stable_args = [] + + if torch_version >= target_version: + flash_api_source = "flash_api_stable.cpp" + stable_args = ["-DTORCH_STABLE_ONLY"] # Checks against including unstable Tensor APIs + else: + flash_api_source = "flash_api.cpp" + sources = ( - ["flash_api.cpp"] + [flash_api_source] + (sources_fwd_sm80 if not DISABLE_SM8x else []) + sources_fwd_sm90 + (sources_bwd_sm80 if not DISABLE_SM8x else []) + sources_bwd_sm90 ) @@ -606,13 +621,14 @@ def nvcc_threads_args(): ext_modules.append( CUDAExtension( - name="flash_attn_3_cuda", + name=f"{PACKAGE_NAME}._C", sources=sources, extra_compile_args={ - "cxx": ["-O3", "-std=c++17"] + feature_args, + "cxx": ["-O3", "-std=c++17", "-DPy_LIMITED_API=0x03090000"] + stable_args + feature_args, "nvcc": nvcc_threads_args() + nvcc_flags + cc_flag + feature_args, }, include_dirs=include_dirs, + py_limited_api=True, ) ) @@ -699,7 +715,7 @@ def run(self): "benchmarks", ) ), - py_modules=["flash_attn_interface"], + py_modules=["flash_attn_interface", "flash_attn_config"], description="FlashAttention-3", long_description=long_description, long_description_content_type="text/markdown", @@ -721,4 +737,5 @@ def run(self): "packaging", "ninja", ], + options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index cbb7f70089b..fe6808a0acd 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -5,6 +5,12 @@ import pytest import torch import torch.nn.functional as F +from torch._C import parse_schema +from torch.testing._internal.optests.generate_tests import ( + safe_fake_check, + safe_schema_check, + safe_aot_autograd_check, +) from einops import rearrange, repeat try: @@ -37,25 +43,8 @@ DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" -DISABLE_HDIMDIFF64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF64", "FALSE") == "TRUE" -DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE" - -DISABLE_BACKWARD = True -# DISABLE_SPLIT = True -# DISABLE_PAGEDKV = True -# DISABLE_APPENDKV = True -# DISABLE_LOCAL = True -# DISABLE_SOFTCAP = True -# DISABLE_PACKGQA = True -# DISABLE_FP16 = True -# DISABLE_FP8 = True -# DISABLE_HDIM64 = True -# DISABLE_HDIM96 = True -# DISABLE_HDIM128 = True -# DISABLE_HDIM192 = True -# DISABLE_HDIM256 = True 
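[Editor's note] For reference, a minimal sketch of the torch-version gate that the setup.py hunk above introduces for the C++ binding source. Only the cutoff version, the two file names, and the -DTORCH_STABLE_ONLY flag come from the diff; the helper function name is illustrative.

import torch
from packaging.version import parse

def select_flash_api_source(torch_version: str = torch.__version__):
    # torch >= 2.9.0.dev20250830 can build against the stable Tensor API:
    # compile flash_api_stable.cpp with -DTORCH_STABLE_ONLY; older torch keeps flash_api.cpp.
    if parse(torch_version) >= parse("2.9.0.dev20250830"):
        return "flash_api_stable.cpp", ["-DTORCH_STABLE_ONLY"]
    return "flash_api.cpp", []

print(select_flash_api_source("2.8.0"))              # ('flash_api.cpp', [])
print(select_flash_api_source("2.9.0.dev20250901"))  # ('flash_api_stable.cpp', ['-DTORCH_STABLE_ONLY'])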
-DISABLE_HDIMDIFF64 = True -# DISABLE_HDIMDIFF192 = True +ENABLE_OPCHECK = os.getenv("FLASH_ATTENTION_ENABLE_OPCHECK", "FALSE") == "TRUE" +ENABLE_AUTOGRAD_CHECK = os.getenv("FLASH_ATTENTION_ENABLE_AUTOGRAD_CHECK", "FALSE") == "TRUE" COMPILED_HDIMS = ( [] @@ -66,6 +55,61 @@ + ([256] if not DISABLE_HDIM256 else []) ) +def should_test_backward(args, kwargs): + v = args[2] + num_splits = kwargs.get("num_splits", 1) + dtype = v.dtype + has_qv = V_colmajor = False # no test runs this with V_colmajor or has_qv == True + attention_chunk = kwargs.get("attention_chunk") + dv = v.size(-1) + + if ( + ENABLE_AUTOGRAD_CHECK + and not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and num_splits > 0 # we don't support num_split == 0 on torch.compile yet + ): + return True + return False + + +def should_run_schema_check(args, kwargs): + v = args[2] + if v.dtype == torch.float8_e4m3fn: + return False + return True + + +def should_run_fake_check(args, kwargs): + if 'num_splits' in kwargs: + return kwargs['num_splits'] > 0 + return True + + +def run_opcheck(fn): + def wrapper(*args, **kwargs): + if should_run_schema_check(args, kwargs): + safe_schema_check(fn, args, kwargs) + + if should_run_fake_check(args, kwargs): + safe_fake_check(fn, args, kwargs) + + if should_test_backward(args, kwargs): + # Expensive check + safe_aot_autograd_check(fn, args, kwargs, dynamic=False) + safe_aot_autograd_check(fn, args, kwargs, dynamic=True) + return fn(*args, **kwargs) + return wrapper + + +if ENABLE_OPCHECK: + flash_attn_func = run_opcheck(flash_attn_func) + flash_attn_varlen_func = run_opcheck(flash_attn_varlen_func) + # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) @@ -73,8 +117,8 @@ # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -@pytest.mark.parametrize("has_qv_", [False] + ([True] if not DISABLE_HDIMDIFF64 else [])) -# @pytest.mark.parametrize("has_qv_", [True]) +@pytest.mark.parametrize("has_qv", [False, True]) +# @pytest.mark.parametrize("has_qv", [True]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [True]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -94,8 +138,6 @@ # @pytest.mark.parametrize("d", [64, 96, 128, 192]) @pytest.mark.parametrize("d", COMPILED_HDIMS) # @pytest.mark.parametrize("d", [64]) -@pytest.mark.parametrize("test_sink", [False, True]) -# @pytest.mark.parametrize("test_sink", [False]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -133,17 +175,12 @@ ], ) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv_, mha_type, dtype, test_sink, - cp_world_size, cp_rank, cp_tot_seqlen_k_offset + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype ): if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") - if has_qv_ and (d != 64 or dtype == torch.float8_e4m3fn): + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or 
bfloat16 (not float8_e4m3fn)") - if test_sink and has_qv_: - pytest.skip("Sink disabled for Qv") - if cp_world_size > 1 and local: - pytest.skip("context parallelism is not supported with local attention yet") device = "cuda" # set seed torch.random.manual_seed(0) @@ -155,22 +192,14 @@ def test_flash_attn_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - if d == 192 and not DISABLE_HDIMDIFF192: - dv_vals = [128, d] - elif d == 64 and not DISABLE_HDIMDIFF64 and dtype != torch.float8_e4m3fn: - dv_vals = [256, 512, d] - else: - dv_vals = [d] - s_aux = torch.randn(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None - # s_aux = torch.ones(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None - # print("s_aux ", s_aux) - cp_tot_seqlen_k = seqlen_k * cp_world_size + cp_tot_seqlen_k_offset - cp_tot_seqlen_k = torch.full((batch_size,), cp_tot_seqlen_k, device=device, dtype=torch.int32) - if test_sink: + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: dv_vals = [d] - for dv in dv_vals: - print("dv =", dv) - has_qv = has_qv_ and d == 64 and dv >= 256 + if has_qv: + dv_vals = [256, 512] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -183,7 +212,7 @@ def test_flash_attn_output( else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, cp_tot_seqlen_k[0], (2,)) + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() # window_size = (-1, -1) if not local else (16, 0) if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] @@ -203,11 +232,8 @@ def test_flash_attn_output( qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, - softcap=softcap, - s_aux=s_aux, - cp_world_size=cp_world_size, - cp_rank=cp_rank, - cp_tot_seqlen_k=cp_tot_seqlen_k, + attention_chunk=attention_chunk, + softcap=softcap ) out_pt, attn_pt = attention_ref( q_ref, @@ -219,6 +245,7 @@ def test_flash_attn_output( qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, + attention_chunk=attention_chunk, softcap=softcap, upcast=False, reorder_ops=True, @@ -247,7 +274,8 @@ def test_flash_attn_output( pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out, lse = flash_attn_func( + print(f"{pack_gqa = }, {num_splits = }") + out = flash_attn_func( q, k, v, @@ -255,16 +283,11 @@ def test_flash_attn_output( qv=qv, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, + attention_chunk=attention_chunk, softcap=softcap, pack_gqa=pack_gqa, - num_splits=num_splits, - s_aux=s_aux, - cp_world_size=cp_world_size, - cp_rank=cp_rank, - cp_tot_seqused_k=cp_tot_seqlen_k, + num_splits=num_splits 
) - print("Pack GQA =", pack_gqa) - print("Num splits =", num_splits) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") # if not causal: @@ -275,64 +298,68 @@ def test_flash_attn_output( # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv and not test_sink: - g = torch.randn_like(out) - do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) - # import flash_attn_3_cuda - # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd( - # g, - # q, - # k, - # v, - # out, - # lse, - # None, - # None, - # None, - # d ** (-0.5), - # causal, - # window_size[0], window_size[1], - # softcap, - # deterministic, - # 0, # sm_margin - # ) - dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") - # assert (softmax_d - do_o).abs().max().item() <= 1e-5 - # assert dq_accum.abs().max().item() == 0.0 - - # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) - # P = torch.softmax(qk, -1) - # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) - # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) - # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) - # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) - - # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) - dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) - print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") - print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") - print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") - print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") - print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") - print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") - print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") - print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") - print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") - print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") - print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") - print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") - # breakpoint() - - - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv and not test_sink: - dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol - dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol - dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g = torch.randn_like(out) + do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + # import flash_attn_3_cuda + # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd( + # g, + # q, 
+ # k, + # v, + # out, + # lse, + # None, + # None, + # None, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @@ -341,8 +368,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False] + ([True] if not DISABLE_HDIMDIFF64 else [])) -@pytest.mark.parametrize("has_qv_", [False]) +@pytest.mark.parametrize("has_qv", [False, True]) +# @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -360,9 +387,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) @pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("test_sink", [False, True]) -# @pytest.mark.parametrize("test_sink", [True]) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -385,16 +410,16 @@ def 
test_flash_attn_output( (1024, 1024), (1023, 1024), (1024, 1023), + (1024, 1024), (2048, 2048), + (4096, 4096), ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv_, mha_type, dtype, test_sink + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, ): - if has_qv_ and (d != 64 or dtype == torch.float8_e4m3fn): + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") - if test_sink and has_qv_: - pytest.skip("Sink disabled for Qv") device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) @@ -403,20 +428,14 @@ def test_flash_attn_varlen_output( nheads = 6 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - if d == 192 and not DISABLE_HDIMDIFF192: - dv_vals = [128, d] - elif d == 64 and not DISABLE_HDIMDIFF64 and dtype != torch.float8_e4m3fn: - dv_vals = [256, 512, d] - else: - dv_vals = [d] - s_aux = torch.randn(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None - # s_aux = torch.ones(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None - # print("s_aux", s_aux) - if test_sink: + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: dv_vals = [d] - for dv in dv_vals: - print("dv =", dv) - has_qv = has_qv_ and d == 64 and dv >= 256 + if has_qv: + dv_vals = [256, 512] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. 
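[Editor's note] The attention_chunk values sampled in these tests exercise the chunked-attention masking added in mask.h above, where each query row is limited to keys inside its own chunk on top of the causal/local window. Below is a minimal Python reference of that semantics, assuming the bottom-right diagonal alignment used by row_idx + seqlen_k - seqlen_q; the function name is illustrative, not part of the PR.

import torch

def chunked_attention_allowed(seqlen_q, seqlen_k, attention_chunk, causal=True):
    # True where a query may attend to a key (reference sketch, not the kernel code).
    q_idx = torch.arange(seqlen_q).view(-1, 1)
    k_idx = torch.arange(seqlen_k).view(1, -1)
    # Bottom-right aligned diagonal, matching row_idx + seqlen_k - seqlen_q in mask.h.
    diag = q_idx + seqlen_k - seqlen_q
    allowed = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool)
    if causal:
        allowed &= k_idx <= diag
    if attention_chunk > 0:
        # Keys must fall in the query's chunk: [chunk_start, chunk_start + attention_chunk),
        # i.e. the diagonal position rounded down to a chunk boundary (cf. flash::round_down
        # with attention_chunk_divmod in mask.h).
        chunk_start = (diag // attention_chunk) * attention_chunk
        allowed &= (k_idx >= chunk_start) & (k_idx < chunk_start + attention_chunk)
    return allowed

# With attention_chunk=4 and seqlen_q == seqlen_k == 8, query 5 sees only keys 4 and 5.
print(chunked_attention_allowed(8, 8, attention_chunk=4)[5])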
@@ -493,8 +512,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, - softcap=softcap, - s_aux=s_aux, + attention_chunk=attention_chunk, + softcap=softcap ) out_pt, attn_pt = attention_ref( q_ref, @@ -506,6 +525,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, + attention_chunk=attention_chunk, softcap=softcap, upcast=False, reorder_ops=True, @@ -527,7 +547,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out_unpad, lse = flash_attn_varlen_func( + print(f"{pack_gqa = }, {num_splits = }") + out_unpad = flash_attn_varlen_func( q_unpad, k_unpad, v_unpad, @@ -542,13 +563,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, + attention_chunk=attention_chunk, softcap=softcap, - num_splits=num_splits, pack_gqa=pack_gqa, - s_aux=s_aux, + num_splits=num_splits, ) - print("Pack GQA =",pack_gqa) - print("Num splits =",num_splits) out = output_pad_fn(out_unpad) if query_unused_mask is not None: out.masked_fill_(q_zero_masking, 0.0) @@ -563,81 +582,85 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv and not test_sink: - g_unpad = torch.randn_like(out_unpad) - do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) - # import flash_attn_3_cuda - # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( - # g_unpad, - # q_unpad, - # k_unpad, - # v_unpad, - # out_unpad, - # lse, - # None, - # None, - # None, - # cu_seqlens_q, - # cu_seqlens_k, - # None, None, - # max_seqlen_q, - # max_seqlen_k, - # d ** (-0.5), - # causal, - # window_size[0], window_size[1], - # softcap, - # deterministic, - # 0, # sm_margin - # ) - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) - dq = dq_pad_fn(dq_unpad) - dk = dk_pad_fn(dk_unpad) - dv = dk_pad_fn(dv_unpad) - if key_unused_mask is not None: - k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") - dk.masked_fill_(k_zero_masking, 0.0) - dv.masked_fill_(k_zero_masking, 0.0) - if query_unused_mask is not None: - dq.masked_fill_(q_zero_masking, 0.0) - # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") - # assert (softmax_d - do_o).abs().max().item() <= 1e-5 - # assert dq_accum.abs().max().item() == 0.0 - g = output_pad_fn(g_unpad) - - # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() - # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) - # P = torch.softmax(qk, -1) - # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) - # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) - # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) - # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) - - - # dq, dk, dv = 
torch.autograd.grad(out, (q, k, v), g) - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) - dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) - print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") - print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") - print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") - print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") - print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") - print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") - print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") - print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") - print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") - print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") - print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") - print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") - # breakpoint() - - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv and not test_sink: - dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol - dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol - dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # import flash_attn_3_cuda + # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( + # g_unpad, + # q_unpad, + # k_unpad, + # v_unpad, + # out_unpad, + # lse, + # None, + # None, + # None, + # cu_seqlens_q, + # cu_seqlens_k, + # None, None, + # max_seqlen_q, + # max_seqlen_k, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + + # dq, dk, dv = 
torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @@ -652,9 +675,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) # @pytest.mark.parametrize("causal,local", [(True, False)]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) -# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) -@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) -# @pytest.mark.parametrize("has_rotary_seqlens", [False]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) @pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) # @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) @@ -674,8 +697,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("d", COMPILED_HDIMS) @pytest.mark.parametrize("d", [128]) # @pytest.mark.parametrize("d", [192]) -# @pytest.mark.parametrize("test_sink", [False, True]) -@pytest.mark.parametrize("test_sink", [False]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -738,12 +759,11 @@ def test_flash_attn_kvcache( assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - # dv_vals = [d] - s_aux = torch.randn(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None - if dtype == torch.float8_e4m3fn and d != 192: + if dtype == torch.float8_e4m3fn: dv_vals = [d] - for dv in dv_vals: - print("dv =", dv) + 
attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") has_qv = d == 64 and dv >= 256 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) if has_qv: @@ -891,8 +911,8 @@ def test_flash_attn_kvcache( causal=causal, qv=qv, window_size=window_size, + attention_chunk=attention_chunk, key_leftpad=cache_leftpad, - s_aux=s_aux, ) out_pt, _ = attention_ref( q_ro, @@ -903,6 +923,7 @@ def test_flash_attn_kvcache( causal=causal, qv=qv, window_size=window_size, + attention_chunk=attention_chunk, upcast=False, reorder_ops=True, key_leftpad=cache_leftpad, @@ -928,9 +949,7 @@ def test_flash_attn_kvcache( num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] precompute_metadata_vals = [False, True] for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): - print("Num splits = ",num_splits) - print("Precompute metadata = ",precompute_metadata) - # print("max seqlen_q, seqlen_q ", max_seqlen_q, seqlen_q) + print(f"{num_splits = }, {precompute_metadata = }") if precompute_metadata: # WARNING: seqlen_k is not max_seqlen_k if using page table, so we can't expect this to make sense? scheduler_metadata = get_scheduler_metadata( @@ -941,8 +960,8 @@ def test_flash_attn_kvcache( cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, max_seqlen_k_new=seqlen_new, page_size=page_size, - causal=causal, window_size=window_size, - num_splits=num_splits + causal=causal, window_size=window_size, attention_chunk=attention_chunk, + num_splits=num_splits, ) else: scheduler_metadata = None @@ -974,11 +993,11 @@ def test_flash_attn_kvcache( rotary_seqlens=rotary_seqlens, causal=causal, window_size=window_size, + attention_chunk=attention_chunk, rotary_interleaved=rotary_interleaved, scheduler_metadata=scheduler_metadata, num_splits=num_splits, return_softmax_lse=True, - s_aux=s_aux, ) if varlen_q: out = output_pad_fn(out) @@ -1133,7 +1152,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) torch.random.manual_seed(42) - out0, lse0 = flash_attn_func(q, k, v, causal=causal) + out0 = flash_attn_func(q, k, v, causal=causal) g = torch.randn_like(out0) dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g) # Numerical error if we just do any arithmetic on dq @@ -1141,9 +1160,9 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): for i in range(1000): torch.random.manual_seed(42) - out, lse = flash_attn_func(q, k, v, causal=causal) + out = flash_attn_func(q, k, v, causal=causal) assert torch.equal(out, out0) - assert torch.equal(lse, lse0) + # assert torch.equal(lse, lse0) dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_equal = torch.allclose(dq, dq0, atol=dq_atol) @@ -1212,3 +1231,43 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype): # # pytorch_profiler(torch.sum, lse_partial) # pytorch_profiler(flash_attn_combine, out_partial, lse_partial) # pytorch_profiler(torch.sum, out_partial) + +def test_flash3_bw_compatibility() -> None: + # Let's try to always stay backward compatible! 
This will make life easier + # for downstream libaries, users, and exported models. + # 1/ Instead of removing arguments, error out if their value is no longer supported + # 2/ When adding arguments, add them at the end with a default value + assert torch.ops.flash_attn_3.fwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, " + "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, " + "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, " + "int? max_seqlen_q=None, int? max_seqlen_k=None, Tensor? page_table=None, " + "Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, " + "Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? v_descale=None, " + "float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, " + "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) " + "-> (Tensor(out!), Tensor, Tensor, Tensor)" + )) + assert torch.ops.flash_attn_3.bwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, " + "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, " + "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, " + "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, " + "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) " + "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)" + )) + assert torch.ops.flash_attn_3.fwd_combine.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, " + "ScalarType? out_dtype=None) -> (Tensor(out!), Tensor)" + )) + assert torch.ops.flash_attn_3.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, " + "int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, " + "Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, " + "bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? 
pack_gqa=None, " + "int sm_margin=0) -> Tensor" + )) diff --git a/hopper/test_flash_attn_bwd_determinism.py b/hopper/test_flash_attn_bwd_determinism.py new file mode 100644 index 00000000000..b443c8948d4 --- /dev/null +++ b/hopper/test_flash_attn_bwd_determinism.py @@ -0,0 +1,706 @@ +import os +import math +import itertools + +import pytest +import torch +import torch.nn.functional as F +from torch._C import parse_schema + +from einops import rearrange, repeat +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from padding import pad_input, unpad_input +from test_util import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, +) + +from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata + +from flash_attn_interface import _flash_attn_backward + + +DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" +DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" +DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" +DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" +DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" +DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" +DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" +DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9 +DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE" +DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE" +DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" +DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" +DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" + +# deterministic mode not supported for hdim 256 +DISABLE_HDIM256 = True + +COMPILED_HDIMS = ( + [] + + ([64] if not DISABLE_HDIM64 else []) + + ([96] if not DISABLE_HDIM96 else []) + + ([128] if not DISABLE_HDIM128 else []) + + ([192] if not DISABLE_HDIM192 else []) + + ([256] if not DISABLE_HDIM256 else []) +) + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mqa"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [True]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("V_colmajor", [False, 
True]) +@pytest.mark.parametrize("V_colmajor", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128, 192]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + # (4224, 4224), + # (8192, 8192), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype +): + if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): + pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") + if deterministic and d == 256: + pytest.skip("Deterministic mode not supported for hdim 256") + device = "cuda" + # set seed + torch.random.manual_seed(0) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 1 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + # if dtype == torch.float8_e4m3fn: + # dv_vals = [d] + # if has_qv: + # dv_vals = [256, 512] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] + dv_vals = [d] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
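+            # attention_ref applies scores = softcap * tanh(scores / softcap); with unit-variance
+            # q and k the raw scores stay O(1), where tanh is close to the identity, so scaling q
+            # by softcap / 4 pushes typical scores toward the cap and actually exercises softcapping.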
+ q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + # window_size = (-1, -1) if not local else (16, 0) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + if V_colmajor: + v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # if qv is not None: + # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + # num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + pack_gqa_vals = [False] + num_splits_vals = [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") + out, softmax_lse = flash_attn_func( + q, + k, + v, + causal=causal, + qv=qv, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits, + return_attn_probs=True, + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
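+            # The tolerance is built from measured quantities rather than fixed constants:
+            # fwd_atol is the floating-point noise floor of out_ref itself (the +0.3 / -0.3 trick
+            # above), out_pt is the same reference recomputed in low precision (upcast=False,
+            # reorder_ops=True), and rtol is 2 without softcap and 3 with it, so the check is
+            # err(flash) <= rtol * err(low-precision reference) + fwd_atol.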
+ assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g = torch.randn_like(out) + do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dq, dk, dv, softmax_d = _flash_attn_backward( + g, + q, + k, + v, + out, + softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, + dq, + dk, + dv, + d ** (-0.5), + causal, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + ) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + if deterministic: + iterations = 1000 + + for i in range(iterations): + dq2 = torch.empty_like(dq) + dk2 = torch.empty_like(dk) + dv2 = torch.empty_like(dv) + dq2, dk2, dv2, softmax_d = _flash_attn_backward( + g, + q, + k, + v, + out, + softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, + dq2, + dk2, + dv2, + d ** (-0.5), + causal, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + ) + print(f'dq max diff with myself: {(dq2 - dq).abs().max().item()}') + print(f'dk max diff with myself: {(dk2 - dk).abs().max().item()}') + print(f'dv max diff with myself: 
{(dv2 - dv).abs().max().item()}') + assert torch.equal(dq, dq2), f"dq not deterministic" + assert torch.equal(dk, dk2), f"dk not deterministic" + assert torch.equal(dv, dv2), f"dv not deterministic" + print(f"✅ Iteration {i} passed!") + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("deterministic", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (1, 3), + (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (1024, 1024), + (2048, 2048), + (4096, 4096), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, +): + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") + if deterministic and d == 256: + pytest.skip("Deterministic mode not supported for hdim 256") + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_q <= 2048 else 2 + # batch_size = 32 + nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + # batch_size = 2 + # nheads = 1 + # nheads_kv = nheads + + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + # if dtype == torch.float8_e4m3fn: + # dv_vals = [d] + # if has_qv: + # dv_vals = [256, 512] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + dv_vals = [d] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") + q_ref = torch.randn(batch_size, 
seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 
2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + # num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] + pack_gqa_vals = [False] + num_splits_vals = [1] + print("cu_seqlens_q: ", cu_seqlens_q) + print("cu_seqlens_k: ", cu_seqlens_k) + print("seqused_q: ", seqused_q) + print("seqused_k: ", seqused_k) + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") + out_unpad, softmax_lse = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + qv=qv_unpad, + q_descale=q_descale, + k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits, + deterministic=deterministic, + return_attn_probs=True, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + dq_unpad = torch.empty_like(q_unpad) + dk_unpad = torch.empty_like(k_unpad) + dv_unpad = torch.empty_like(v_unpad) + dq_unpad, dk_unpad, dv_unpad, softmax_d = _flash_attn_backward( + g_unpad, + q_unpad, + k_unpad, + v_unpad, + out_unpad, + softmax_lse, + cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, + max_seqlen_q, max_seqlen_k, + dq_unpad, + dk_unpad, + dv_unpad, + d ** (-0.5), + causal, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + ) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = 
torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + print(dq_unpad.shape) + print(dk_unpad.shape) + print(dv_unpad.shape) + + print(dq.shape) + print(dk.shape) + print(dv.shape) + + if deterministic: + iterations = 1000 + + for i in range(iterations): + dq_unpad2 = torch.empty_like(q_unpad) + dk_unpad2 = torch.empty_like(k_unpad) + dv_unpad2 = torch.empty_like(v_unpad) + dq_unpad2, dk_unpad2, dv_unpad2, softmax_d = _flash_attn_backward( + g_unpad, + q_unpad, + k_unpad, + v_unpad, + out_unpad, + softmax_lse, + cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, + max_seqlen_q, max_seqlen_k, + dq_unpad2, + dk_unpad2, + dv_unpad2, + d ** (-0.5), + causal, + window_size=window_size, + softcap=softcap, + deterministic=deterministic, + ) + + dq2 = dq_pad_fn(dq_unpad2) + dk2 = dk_pad_fn(dk_unpad2) + dv2 = dk_pad_fn(dv_unpad2) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk2.masked_fill_(k_zero_masking, 0.0) + dv2.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq2.masked_fill_(q_zero_masking, 0.0) + + print(f'dq max diff with myself: {(dq2 - dq).abs().max().item()}') + print(f'dk max diff with myself: {(dk2 - dk).abs().max().item()}') + print(f'dv max diff with myself: {(dv2 - dv).abs().max().item()}') + + assert torch.equal(dq, dq2), f"dq not deterministic" + assert torch.equal(dk, dk2), f"dk not deterministic" + assert torch.equal(dv, dv2), f"dv not deterministic" + + print(f"✅ Iteration {i} passed!") \ No newline at end of file diff --git a/hopper/test_torch_compile_and_export.py b/hopper/test_torch_compile_and_export.py new file mode 100644 index 00000000000..53beef46340 --- /dev/null +++ b/hopper/test_torch_compile_and_export.py @@ -0,0 +1,73 @@ +import torch +from flash_attn_interface import flash_attn_func +from torch import nn + + +class EfficienctMultiHeadAttention(nn.Module): + def __init__(self, embed_size, num_heads, dropout=0.0, use_flash_attn=True): + super().__init__() + assert embed_size % num_heads == 0, f"{embed_size=} {num_heads=}" + + self.embed_size = embed_size + self.num_heads = num_heads + self.head_dim = embed_size 
// num_heads + self.use_flash_attn = use_flash_attn and (flash_attn_func is not None) + + self.qkv_proj = nn.Linear(embed_size, 3 * embed_size) + self.out_proj = nn.Linear(embed_size, embed_size) + self.dropout = dropout + + def forward(self, x, attention_mask=None): + N, seq_length, _ = x.shape + + qkv = self.qkv_proj(x) + q, k, v = qkv.chunk(3, dim=-1) + + q = q.view(N, seq_length, self.num_heads, self.head_dim) + k = k.view(N, seq_length, self.num_heads, self.head_dim) + v = v.view(N, seq_length, self.num_heads, self.head_dim) + + if self.use_flash_attn and attention_mask is None: + out = flash_attn_func( + q, k, v + ) + out = out.reshape(N, seq_length, self.embed_size) + out = self.out_proj(out) + return out + + +def create_model(batch_size=16, sequence_length=256, embedding_dim=2048, num_heads=16): + model = EfficienctMultiHeadAttention(embedding_dim, num_heads).cuda().bfloat16() + input_tensor = torch.randn(batch_size, sequence_length, embedding_dim).cuda().bfloat16() + return model, input_tensor + + +def test_export_model(): + model, input_tensor = create_model() + expected = torch.compile(model, backend="aot_eager")(input_tensor) + loss = expected.sum() + loss.backward() + + ep = torch.export.export(model, (input_tensor,)) + got = ep.module()(input_tensor,) + assert torch.equal(expected, got) + + loss_2 = got.sum() + loss_2.backward() + + assert torch.equal(loss, loss_2) + + +def test_compile_and_package_model(): + model, input_tensor = create_model() + expected = torch.compile(model, backend="aot_eager")(input_tensor) + + exported = torch.export.export(model, (input_tensor,)) + torch._inductor.aoti_compile_and_package( + exported, + package_path="model.pt2", + ) + + compiled_model = torch._inductor.package.load_package("model.pt2") + out = compiled_model(input_tensor,) + assert torch.equal(expected, out) diff --git a/hopper/test_util.py b/hopper/test_util.py index a24faf96d34..9b7135d3a9d 100644 --- a/hopper/test_util.py +++ b/hopper/test_util.py @@ -307,6 +307,56 @@ def construct_cp_mask( return mask +def construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, + cp_world_size=1, + cp_rank=0, + cp_tot_seqlen_k=None, +): + if cp_world_size > 1: + return construct_cp_mask( + seqlen_q, + seqlen_k, + cp_world_size=cp_world_size, + cp_rank=cp_rank, + cp_tot_seqlen_k=cp_tot_seqlen_k, + window_size=window_size, + sink_token_length=sink_token_length, + query_padding_mask=query_padding_mask, + key_padding_mask=key_padding_mask, + key_leftpad=key_leftpad, + device=device, + ) + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + # Subtract remainder instead of divide and then multiply to take care of negative values + col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk + return torch.logical_or( + col_idx < col_limit_left_chunk, 
col_idx >= col_limit_left_chunk + attention_chunk + ) + + def attention_ref( q, k, @@ -321,6 +371,7 @@ def attention_ref( qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), # -1 means infinite window size + attention_chunk=0, sink_token_length=0, softcap=0.0, upcast=True, @@ -334,8 +385,8 @@ def attention_ref( """ Arguments: q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads_kv, head_dim) - v: (batch_size, seqlen_k, nheads_kv, head_dim_v) + k: (batch_size, seqlen_k, nheads, head_dim) + v: (batch_size, seqlen_k, nheads, head_dim_v) qv: (batch_size, seqlen_q, nheads, head_dim_v) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) @@ -365,7 +416,6 @@ def attention_ref( if upcast: q, k, v = q.float(), k.float(), v.float() qv = qv.float() if qv is not None else None - s_aux = s_aux.float() if s_aux is not None else None if q_descale is not None: q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) q = (q.float() * q_descale).to(q.dtype) @@ -389,6 +439,7 @@ def attention_ref( scores = torch.tanh(scores / softcap) * softcap if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + local_mask = None if window_size[0] >= 0 or window_size[1] >= 0: local_mask = construct_local_mask( seqlen_q, @@ -403,6 +454,18 @@ def attention_ref( cp_rank=cp_rank, cp_tot_seqlen_k=cp_tot_seqlen_k, ) + if attention_chunk > 0: + chunk_mask = construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask + if local_mask is not None: scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias @@ -422,7 +485,7 @@ def attention_ref( if key_padding_mask is not None: attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) # Some rows might be completely masked out so we fill them with zero instead of NaN - if window_size[0] >= 0 or window_size[1] >= 0: + if local_mask is not None: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 1109de2b8a4..7227386a2d4 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -25,7 +25,7 @@ struct TileSchedulerArguments { int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; int const* const num_splits_dynamic_ptr = nullptr; - int const* const prepare_seqlen_q_ptr = nullptr; + int const* const num_m_blocks_ptr = nullptr; int const* const varlen_batch_idx_ptr = nullptr; // int const* const num_n_blocks_ptr = nullptr; int const* const num_nheads_in_l2_ptr = nullptr; @@ -255,12 +255,14 @@ class DynamicPersistentTileScheduler { static Params to_underlying_arguments(TileSchedulerArguments const& args) { - int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size * 2; + long long const size_one_kv_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size); int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V // Swizzle is the size of each "section". 
Round swizzle to a power of 2 // If not PackGQA already, the size of each section can increase by qhead_per_khead // Need to be careful about the case where only one head will fit - int const swizzle = (size_l2 < size_one_kv_head ? 1 : (1 << cutlass::find_log2(size_l2 / size_one_kv_head))) * (PackGQA ? 1 : args.qhead_per_khead); + auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; + // Seems faster if swizzle if a power of 2 + int const swizzle = (size_l2 < size_one_kv_head ? 1 : (1 << find_log2_floor(size_l2 / size_one_kv_head))) * (PackGQA ? 1 : args.qhead_per_khead); // If we're in the last section (called residual), we don't want to divide by // swizzle. Instead we want to divide by the remainder. int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; @@ -364,6 +366,132 @@ class DynamicPersistentTileScheduler { }; +/////////////////////////////////////////////////////////////////////////////// + +template +class SingleTileBwdLPTScheduler { + +public: + + using SharedStorage = int; + + // Device side kernel params + struct Params { + int const total_blocks; + cutlass::FastDivmod const block_divmod, head_divmod; + cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; + cutlass::FastDivmod const l2_minor_residual_divmod; + int const num_hb_quotient; + int const seqlen; + int const* const cu_seqlens; + int const* const seqused; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + // Since it's the bwd pass, seqlen_k get passed to args.seqlen and seqlen_q is passed to args.seqlen_k + long long const size_one_qdo_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size); + long long const size_one_dqaccum_head = long(args.seqlen_k) * long(args.headdim) * sizeof(float); + long long const size_one_head = size_one_qdo_head + size_one_dqaccum_head; + int const size_l2 = 40 * 1024 * 1024; // 40 MB for Q, dO, and dQaccum + // Swizzle is the size of each "section". Round swizzle to a power of 2 + // Need to be careful about the case where only one head will fit + auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; + // Seems faster if swizzle if a power of 2 + int const swizzle = size_l2 < size_one_head ? 1 : (1 << find_log2_floor(size_l2 / size_one_head)); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. + int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; + // printf("num_blocks = %d, num_head = %d, num_batch = %d, size_one_head = %d, ratio = %d, swizzle = %d, num_hb_remainder = %d\n", args.num_blocks, args.num_head, args.num_batch, size_one_head, size_l2 / size_one_head, swizzle, num_hb_remainder); + assert(args.tile_count_semaphore != nullptr); + return {args.num_blocks * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head), + cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * args.num_blocks), + // don't divide by 0 + cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), + (args.num_head * args.num_batch) / swizzle, + args.seqlen, !Varlen ? nullptr : args.cu_seqlens, !Varlen ? 
nullptr : args.seqused}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(params.total_blocks)}; + } + + struct WorkTileInfo { + int block; + int bidh; + int bidb; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return bidb >= 0; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + return {block, bidh, bidb, 0 /*split_idx*/}; + } + + }; + + CUTLASS_DEVICE + SingleTileBwdLPTScheduler(SharedStorage* const smem_scheduler) { } + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + int tile_idx = blockIdx.x; + int block, bidh, bidb; + int l2_mod, bidhb, bidhb_residual; + bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. + if (bidhb < params.num_hb_quotient) { + block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); + } else { + block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); + } + bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); + bool is_valid_tile = true; + int num_blocks; + if constexpr (Varlen) { + int seqlen = params.seqused + ? params.seqused[bidb] + : (params.cu_seqlens ? params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb] : params.seqlen); + num_blocks = cute::ceil_div(seqlen, Int{}); + is_valid_tile = block < num_blocks; + } else { + num_blocks = params.block_divmod.divisor; + } + if constexpr (SPT) { + block = num_blocks - block - 1; + } + return {block, bidh, is_valid_tile ? bidb : -1}; + } + + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {0, 0, -1}; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + template class VarlenDynamicPersistentTileScheduler { @@ -391,7 +519,7 @@ class VarlenDynamicPersistentTileScheduler { int const* const cu_seqlens; int const* const seqused; int const* const num_splits_dynamic_ptr; - int const* const prepare_seqlen_q_ptr; + int const* const num_m_blocks_ptr; int const* const varlen_batch_idx_ptr; // int const* const num_n_blocks_ptr; int const* const num_nheads_in_l2_ptr; @@ -414,7 +542,7 @@ class VarlenDynamicPersistentTileScheduler { cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.tile_count_semaphore, args.cu_seqlens, args.seqused, args.num_splits_dynamic_ptr, - args.prepare_seqlen_q_ptr, + args.num_m_blocks_ptr, args.varlen_batch_idx_ptr, // aras.num_n_blocks_ptr, args.num_nheads_in_l2_ptr}; @@ -476,7 +604,7 @@ class VarlenDynamicPersistentTileScheduler { int batch_idx = lane + bidb_start; if constexpr (Prepared) { return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? cute::ceil_div(params.prepare_seqlen_q_ptr[batch_idx], kBlockM) : 0; + ? params.num_m_blocks_ptr[batch_idx] : 0; } else { int seqlen = params.seqlen * (!PackGQA ? 
1 : params.qhead_per_khead); if (seqlen > kBlockM) { diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 9b9aef5a704..3b89af883b9 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -18,35 +18,23 @@ constexpr std::tuple tile_size_fwd_sm90( if (headdim_v == 512) { return {64, 64, false, false}; } else if (headdim_v == 256) { - return {128, 112, true, false}; + return {128, 96, true, false}; } else { - if (use_one_mma_wg) { - return {64, 192, true, true}; - } else { - // Switch to tile size 192 x 192 for now - // bool const use_blockN_128 = is_causal || is_local; - // return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; // BASE - // Benefits SWA when window length <= 128 - return {192, is_causal ? 128 : is_local || paged_kv_non_TMA ? 160 : 192, is_causal || is_local, !is_local}; - // return {192, is_causal ? 128 : 160, true, !is_local}; - // return {128, use_blockN_128 ? 160 : 192, use_blockN_128, !use_blockN_128}; - // return {192, is_local ? 160 : 192, true, false}; - } + // Switch to tile size 192 x 192 for now + bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; + return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; } // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; } else if (headdim <= 128) { - if (use_one_mma_wg) { - return {64, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; - } else { - return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 160, true, true}; - } - // {128, 192, false, false} and {192, 128, false, true} are quite good too + bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; + return {128, use_blockN_128 ? 128 : 176, true, true}; + // {128, 192, true, false} and {192, 128, false, true} are quite good too // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { - return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 96), true, true}; // 128 x 112 hits the limit of smem + return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem } else { return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem } @@ -60,11 +48,7 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim <= 96) { return {192, 128, true, true}; } else if (headdim <= 128) { - if (use_one_mma_wg) { - return {64, 96, true, true}; - } else{ - return {128, paged_kv_non_TMA ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; - } + return {128, paged_kv_non_TMA ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; } else if (headdim <= 192) { return {128, (paged_kv_non_TMA || softcap) && is_local ? 
128 : 160, true, true}; } else { diff --git a/hopper/utils.h b/hopper/utils.h index efc755f7d7f..4253ffadb6f 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -99,6 +99,25 @@ static __device__ __forceinline__ T run(T x, Operator &op) { //////////////////////////////////////////////////////////////////////////////////////////////////// +CUTLASS_HOST_DEVICE +int div_floor(cutlass::FastDivmod const& divmod, int dividend) { + // Take care of the negative case: https://stackoverflow.com/questions/39304681/division-with-negative-dividend-but-rounded-towards-negative-infinity + // Maybe the compiler will turn the -1 - * into bit negation operation, I haven't checked. + return dividend >= 0 ? divmod.divide(dividend) : -1 - divmod.divide(-1 - dividend); +} + +CUTLASS_HOST_DEVICE +int round_down(cutlass::FastDivmod const& divmod, int dividend) { + return div_floor(divmod, dividend) * divmod.divisor; +} + +CUTLASS_HOST_DEVICE +int round_up(cutlass::FastDivmod const& divmod, int dividend) { + return div_floor(divmod, dividend - 1) * divmod.divisor + divmod.divisor; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) // For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) template @@ -316,15 +335,6 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const } } } - -#ifndef FLASHATTENTION_DISABLE_FP8_TWO_LEVEL_ACCUMULATION - if constexpr (Use_Two_Level) { - #pragma unroll - for (int i = 0; i < cute::size(tCrC); ++i) { - tCrC(i) = tCrC_original(i) + tCrC(i); // Add temp results to original - } - } -#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/setup.py b/setup.py index d2479f600ff..3543c2527ec 100644 --- a/setup.py +++ b/setup.py @@ -40,26 +40,350 @@ # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) -PACKAGE_NAME = "vllm_flash_attn" +BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto") -cmdclass = {} -ext_modules = [] +if BUILD_TARGET == "auto": + if IS_HIP_EXTENSION: + IS_ROCM = True + else: + IS_ROCM = False +else: + if BUILD_TARGET == "cuda": + IS_ROCM = False + elif BUILD_TARGET == "rocm": + IS_ROCM = True + +PACKAGE_NAME = "flash_attn" + +BASE_WHEEL_URL = ( + "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" +) + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +SKIP_CK_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CK_BUILD", "TRUE") == "TRUE" if USE_TRITON_ROCM else False +NVCC_THREADS = os.getenv("NVCC_THREADS") or "4" + +@functools.lru_cache(maxsize=None) +def cuda_archs() -> str: + return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;110;120").split(";") + + +def get_platform(): + """ + Returns the platform name 
as used in wheel filenames. + """ + if sys.platform.startswith("linux"): + return f'linux_{platform.uname().machine}' + elif sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + elif sys.platform == "win32": + return "win_amd64" + else: + raise ValueError("Unsupported platform: {}".format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def add_cuda_gencodes(cc_flag, archs, bare_metal_version): + """ + Adds -gencode flags based on nvcc capabilities: + - sm_80/90 (regular) + - sm_100/120 on CUDA >= 12.8 + - Use 100f on CUDA >= 12.9 (Blackwell family-specific) + - Map requested 110 -> 101 if CUDA < 13.0 (Thor rename) + - Embed PTX for newest arch for forward compatibility + """ + # Always-regular 80 + if "80" in archs: + cc_flag += ["-gencode", "arch=compute_80,code=sm_80"] + + # Hopper 9.0 needs >= 11.8 + if bare_metal_version >= Version("11.8") and "90" in archs: + cc_flag += ["-gencode", "arch=compute_90,code=sm_90"] + + # Blackwell 10.x requires >= 12.8 + if bare_metal_version >= Version("12.8"): + if "100" in archs: + # CUDA 12.9 introduced "family-specific" for Blackwell (100f) + if bare_metal_version >= Version("12.9"): + cc_flag += ["-gencode", "arch=compute_100f,code=sm_100"] + else: + cc_flag += ["-gencode", "arch=compute_100,code=sm_100"] + + if "120" in archs: + # sm_120 is supported in CUDA 12.8/12.9+ toolkits + if bare_metal_version >= Version("12.9"): + cc_flag += ["-gencode", "arch=compute_120f,code=sm_120"] + else: + cc_flag += ["-gencode", "arch=compute_120,code=sm_120"] + + + # Thor rename: 12.9 uses sm_101; 13.0+ uses sm_110 + if "110" in archs: + if bare_metal_version >= Version("13.0"): + cc_flag += ["-gencode", "arch=compute_110f,code=sm_110"] + else: + # Provide Thor support for CUDA 12.9 via sm_101 + if bare_metal_version >= Version("12.8"): + cc_flag += ["-gencode", "arch=compute_101,code=sm_101"] + # else: no Thor support in older toolkits + + # PTX for newest requested arch (forward-compat) + numeric = [a for a in archs if a.isdigit()] + if numeric: + newest = max(numeric, key=int) + cc_flag += ["-gencode", f"arch=compute_{newest},code=compute_{newest}"] + + return cc_flag + + +def get_hip_version(): + return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+')) + + +def check_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def check_if_rocm_home_none(global_option: str) -> None: + if ROCM_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but hipcc was not found." 
+ ) + + +def detect_hipify_v2(): + try: + from torch.utils.hipify import __version__ + from packaging.version import Version + if Version(__version__) >= Version("2.0.0"): + return True + except Exception as e: + print("failed to detect pytorch hipify version, defaulting to version 1.0.0 behavior") + print(e) + return False + + +def append_nvcc_threads(nvcc_extra_args): + return nvcc_extra_args + ["--threads", NVCC_THREADS] -# TODO(luka): This should be replaced with a fetch_content call in CMakeLists.txt -subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) +def rename_cpp_to_cu(cpp_files): + for entry in cpp_files: + shutil.copy(entry, os.path.splitext(entry)[0] + ".cu") -def is_sccache_available() -> bool: - return which("sccache") is not None +def validate_and_update_archs(archs): + # List of allowed architectures + allowed_archs = ["native", "gfx90a", "gfx950", "gfx942"] -def is_ccache_available() -> bool: - return which("ccache") is not None + # Validate if each element in archs is in allowed_archs + assert all( + arch in allowed_archs for arch in archs + ), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention" +cmdclass = {} +ext_modules = [] + +# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp +# files included in the source distribution, in case the user compiles from source. +if os.path.isdir(".git"): + if not SKIP_CK_BUILD: + subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"], check=True) + subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True) +else: + if IS_ROCM: + if not SKIP_CK_BUILD: + assert ( + os.path.exists("csrc/composable_kernel/example/ck_tile/01_fmha/generate.py") + ), "csrc/composable_kernel is missing, please use source distribution or git clone" + else: + assert ( + os.path.exists("csrc/cutlass/include/cutlass/cutlass.h") + ), "csrc/cutlass is missing, please use source distribution or git clone" + + + check_if_cuda_home_none("flash_attn") + # Check, if CUDA11 is installed for compute capability 8.0 + cc_flag = [] + if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("11.7"): + raise RuntimeError( + "FlashAttention is only supported on CUDA 11.7 and above. " + "Note: make sure nvcc has a supported version by running nvcc -V." 
+ ) + # Build -gencode (regular + PTX + family-specific 'f' when available) + add_cuda_gencodes(cc_flag, set(cuda_archs()), bare_metal_version) + else: + # No nvcc present; warnings already emitted above + pass + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + + nvcc_flags = [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + # "--ptxas-options=-v", + # "--ptxas-options=-O2", + # "-lineinfo", + # "-DFLASHATTENTION_DISABLE_BACKWARD", + # "-DFLASHATTENTION_DISABLE_DROPOUT", + # "-DFLASHATTENTION_DISABLE_ALIBI", + # "-DFLASHATTENTION_DISABLE_SOFTCAP", + # "-DFLASHATTENTION_DISABLE_UNEVEN_K", + # "-DFLASHATTENTION_DISABLE_LOCAL", + ] + + compiler_c17_flag=["-O3", "-std=c++17"] + # Add Windows-specific flags + if sys.platform == "win32" and os.getenv('DISTUTILS_USE_SDK') == '1': + nvcc_flags.extend(["-Xcompiler", "/Zc:__cplusplus"]) + compiler_c17_flag=["-O2", "/std:c++17", "/Zc:__cplusplus"] + + ext_modules.append( + CUDAExtension( + name="flash_attn_2_cuda", + sources=[ + "csrc/flash_attn/flash_api.cpp", + "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu", + 
"csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu", + ], + extra_compile_args={ + "cxx": compiler_c17_flag, + "nvcc": append_nvcc_threads(nvcc_flags + cc_flag), + }, + include_dirs=[ + Path(this_dir) / "csrc" / "flash_attn", + Path(this_dir) / "csrc" / "flash_attn" / "src", + Path(this_dir) / "csrc" / "cutlass" / "include", + ], + ) + ) +elif not SKIP_CUDA_BUILD and IS_ROCM: + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + # Skips CK C++ extension compilation if using Triton Backend + if not SKIP_CK_BUILD: + ck_dir = "csrc/composable_kernel" + def is_ninja_available() -> bool: return which("ninja") is not None + optdim = os.getenv("OPT_DIM", "32,64,128,256") + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2", 
"--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) def remove_prefix(text, prefix): if text.startswith(prefix): @@ -67,24 +391,65 @@ def remove_prefix(text, prefix): return text -VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE + if archs != ['native']: + cc_flag = [f"--offload-arch={arch}" for arch in archs] + else: + arch = torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0] + cc_flag = [f"--offload-arch={arch}"] def _is_cuda() -> bool: has_cuda = torch.version.cuda is not None return VLLM_TARGET_DEVICE == "cuda" and has_cuda + # Check if torch is using hipify v2. Until CK is updated with HIPIFY_V2 macro, + # we must replace the incorrect APIs. + maybe_hipify_v2_flag = [] + if detect_hipify_v2(): + maybe_hipify_v2_flag = ["-DHIPIFY_V2"] + + rename_cpp_to_cu(sources) def _is_hip() -> bool: return (VLLM_TARGET_DEVICE == "cuda" or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None + cc_flag += ["-O3","-std=c++20", + "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", + "-fgpu-flush-denormals-to-zero", + "-DCK_ENABLE_BF16", + "-DCK_ENABLE_BF8", + "-DCK_ENABLE_FP16", + "-DCK_ENABLE_FP32", + "-DCK_ENABLE_FP64", + "-DCK_ENABLE_FP8", + "-DCK_ENABLE_INT8", + "-DCK_USE_XDL", + "-DUSE_PROF_API=1", + # "-DFLASHATTENTION_DISABLE_BACKWARD", + "-D__HIP_PLATFORM_HCC__=1"] def is_freethreaded(): return bool(sysconfig.get_config_var("Py_GIL_DISABLED")) - -class CMakeExtension(Extension): + # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214 + hip_version = get_hip_version() + if hip_version > Version('5.5.00000'): + cc_flag += ["-mllvm", "--lsr-drop-solution=1"] + if hip_version > Version('5.7.23302'): + cc_flag += ["-fno-offload-uniform-block"] + if hip_version > Version('6.1.40090'): + cc_flag += ["-mllvm", "-enable-post-misched=0"] + if hip_version > Version('6.2.41132'): + cc_flag += ["-mllvm", "-amdgpu-early-inline-all=true", + "-mllvm", "-amdgpu-function-calls=false"] + if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'): + cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"] + + extra_compile_args = { + "cxx": ["-O3", "-std=c++20"] + generator_flag + maybe_hipify_v2_flag, + "nvcc": cc_flag + generator_flag + maybe_hipify_v2_flag, + } def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None: super().__init__(name, sources=[], py_limited_api=not is_freethreaded(), **kwa) @@ -286,6 +651,26 @@ def get_version() -> str: version += f"+cu{cuda_version_str}" return version + nvcc_threads = max(1, int(NVCC_THREADS)) + + # calculate the maximum allowed NUM_JOBS based on cores + max_num_jobs_cores = max(1, os.cpu_count() // 2) + + # calculate the maximum allowed NUM_JOBS based on free memory + free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) # free memory in GB + # Assume worst-case peak observed memory usage of ~5GB per NVCC thread. + # Limit: peak_threads = max_jobs * nvcc_threads and peak_threads * 5GB <= free_memory. + max_num_jobs_memory = max(1, int(free_memory_gb / (5 * nvcc_threads))) + + # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation + max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory)) + print( + f"Auto set MAX_JOBS to `{max_jobs}`, NVCC_THREADS to `{nvcc_threads}`. 
" + "If you see memory pressure, please use a lower `MAX_JOBS=N` or `NVCC_THREADS=N` value." + ) + os.environ["MAX_JOBS"] = str(max_jobs) + + super().__init__(*args, **kwargs) ext_modules.append(CMakeExtension(name="vllm_flash_attn._vllm_fa2_C")) ext_modules.append(CMakeExtension(name="vllm_flash_attn._vllm_fa3_C")) diff --git a/tests/cute/benchmark_block_sparsity.py b/tests/cute/benchmark_block_sparsity.py new file mode 100644 index 00000000000..ed6bfad2daa --- /dev/null +++ b/tests/cute/benchmark_block_sparsity.py @@ -0,0 +1,393 @@ +""" +Comparative benchmark: CuTe DSL vs Native PyTorch block sparsity computation. +""" + +import torch +from dataclasses import dataclass +from typing import Callable, Optional, List +from tabulate import tabulate +from tqdm import tqdm +import itertools + +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.testing import benchmark as cute_benchmark +import cutlass.cute as cute +from flash_attn.cute.compute_block_sparsity import BlockSparsityKernel +from flash_attn.cute.block_sparsity import BlockSparseTensors +from mask_mod_definitions import ( + get_mask_pair, + random_doc_id_tensor, + flex_document_mask, + cute_document_mask, +) + +from torch.nn.attention.flex_attention import create_block_mask +from triton.testing import do_bench + +# Configure torch.compile cache to prevent memory buildup +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class BenchmarkConfig: + """Configuration for a benchmark run.""" + + batch_size: int + num_heads: int + seqlen_q: int + seqlen_k: int + mask_name: str + tile_m: int = 128 + tile_n: int = 128 + use_fast_sampling: bool = False + aux_tensors_cute: Optional[list] = None + + +@dataclass(frozen=True) +class BenchmarkResult: + """Result of a single benchmark run.""" + + config: BenchmarkConfig + cute_time_ms: Optional[float] + pytorch_time_ms: Optional[float] + error_message: Optional[str] = None + + +def benchmark_pytorch_block_sparsity( + config: BenchmarkConfig, + mask_fn: Callable, +) -> Optional[float]: + """ + Benchmark PyTorch block mask creation (compiled). + Returns: creation_time_ms + """ + device = "cuda" + + try: + cbm = torch.compile(create_block_mask) + + def run_benchmark(): + return cbm( + mask_fn, + config.batch_size, + config.num_heads, + config.seqlen_q, + config.seqlen_k, + device=device, + ) + + creation_time_ms = do_bench(run_benchmark, warmup=10, rep=100) + + return creation_time_ms + + except Exception as e: + print(f"PyTorch benchmark failed ({config.mask_name}): {e}") + import traceback + + traceback.print_exc() + return None + + +def benchmark_cute_block_sparsity( + config: BenchmarkConfig, + mask_fn: Callable, +) -> Optional[float]: + """ + Benchmark CuTe block sparsity kernel. 
+ Returns: creation_time_ms + """ + device = "cuda" + + try: + num_m_blocks = (config.seqlen_q + config.tile_m - 1) // config.tile_m + num_n_blocks = (config.seqlen_k + config.tile_n - 1) // config.tile_n + + mask_block_cnt = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks), + device=device, + dtype=torch.int32, + ) + mask_block_idx = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks), + device=device, + dtype=torch.int32, + ) + full_block_cnt = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks), + device=device, + dtype=torch.int32, + ) + full_block_idx = torch.zeros( + (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks), + device=device, + dtype=torch.int32, + ) + + # Convert to CuTe tensors + mask_cnt_cute = from_dlpack( + mask_block_cnt.detach(), assumed_align=4 + ).mark_layout_dynamic(leading_dim=2) + mask_idx_cute = from_dlpack( + mask_block_idx.detach(), assumed_align=4 + ).mark_layout_dynamic(leading_dim=3) + full_cnt_cute = from_dlpack( + full_block_cnt.detach(), assumed_align=4 + ).mark_layout_dynamic(leading_dim=2) + full_idx_cute = from_dlpack( + full_block_idx.detach(), assumed_align=4 + ).mark_layout_dynamic(leading_dim=3) + + blocksparse_tensors = BlockSparseTensors( + mask_block_cnt=mask_cnt_cute, + mask_block_idx=mask_idx_cute, + full_block_cnt=full_cnt_cute, + full_block_idx=full_idx_cute, + ) + + # Create kernel + use_aux = ( + config.aux_tensors_cute is not None and len(config.aux_tensors_cute) > 0 + ) + kernel = BlockSparsityKernel( + mask_mod=mask_fn, + tile_mn=(config.tile_m, config.tile_n), + compute_full_blocks=True, + use_aux_tensors=use_aux, + use_fast_sampling=config.use_fast_sampling, + ) + + # Compile kernel + compiled_kernel = cute.compile( + kernel, + blocksparse_tensors, + config.seqlen_q, + config.seqlen_k, + config.aux_tensors_cute, + ) + + def generate_tensors(): + from cutlass.cute.testing import JitArguments + + return JitArguments( + blocksparse_tensors, + config.seqlen_q, + config.seqlen_k, + config.aux_tensors_cute, + ) + + creation_time_us = cute_benchmark( + compiled_kernel, + workspace_generator=generate_tensors, + warmup_iterations=10, + iterations=100, + ) + + torch.cuda.synchronize(device) + creation_time_ms = creation_time_us / 1000.0 + + return creation_time_ms + + except Exception as e: + print(f"CuTe benchmark failed: {e}") + return None + + +def run_benchmark( + config: BenchmarkConfig, + pytorch_mask_fn: Callable, + cute_mask_fn: Callable, +) -> BenchmarkResult: + """Run benchmarks for both implementations.""" + + print( + f"Benchmarking {config.mask_name} - B={config.batch_size}, H={config.num_heads}, " + f"M={config.seqlen_q}, N={config.seqlen_k}" + ) + + # Benchmark PyTorch + pytorch_time = benchmark_pytorch_block_sparsity(config, pytorch_mask_fn) + + # Benchmark CuTe + cute_time = benchmark_cute_block_sparsity(config, cute_mask_fn) + + return BenchmarkResult( + config=config, + cute_time_ms=cute_time, + pytorch_time_ms=pytorch_time, + ) + + +def generate_configs( + batch_sizes: List[int], + num_heads: List[int], + seqlens: List[int], + mask_names: List[str], +) -> List[BenchmarkConfig]: + """Generate all benchmark configurations.""" + configs = [] + for B, H, S, mask_name in itertools.product( + batch_sizes, num_heads, seqlens, mask_names + ): + configs.append( + BenchmarkConfig( + batch_size=B, + num_heads=H, + seqlen_q=S, + seqlen_k=S, + mask_name=mask_name, + ) + ) + return configs + + +def print_results(results: List[BenchmarkResult]): + 
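+    """Print a comparison table of the successful runs: CuTe kernel vs. the
+    torch.compile'd create_block_mask creation time, with
+    Speedup = PyTorch time / CuTe time; configs where either side failed are skipped."""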
successful_results = [ + r + for r in results + if r.cute_time_ms is not None and r.pytorch_time_ms is not None + ] + + if not successful_results: + print("No successful benchmark results to display") + return + + headers = [ + "B", + "H", + "M", + "N", + "Mask Type", + "CuTe Time (ms)", + "PyTorch Time (ms)", + "Speedup", + ] + + rows = [] + for result in successful_results: + speedup = ( + result.pytorch_time_ms / result.cute_time_ms + if result.cute_time_ms > 0 + else 0 + ) + + rows.append( + [ + result.config.batch_size, + result.config.num_heads, + result.config.seqlen_q, + result.config.seqlen_k, + result.config.mask_name, + f"{result.cute_time_ms:.4f}", + f"{result.pytorch_time_ms:.4f}", + f"{speedup:.2f}x", + ] + ) + + # Sort by batch, head, seqlen, then mask type + rows.sort(key=lambda x: (x[0], x[1], x[2], x[4])) + + print("\n" + "=" * 100) + print("CuTe DSL vs PyTorch Block Sparsity Benchmark Results") + print("=" * 100) + print(tabulate(rows, headers=headers, tablefmt="github")) + print("=" * 100) + + +def main(): + """Run the comparative benchmark.""" + + # Configuration + batch_sizes = [1, 4, 8] + num_heads = [8, 16] + seqlens = [1024, 2048, 4096, 8192] + mask_names = [ + "causal", + "sliding_window", + "prefix_lm", + "dilated_sliding_window", + "document", + ] + + device = "cuda" + max_seqlen = max(seqlens) + max_batch = max(batch_sizes) + max_heads = max(num_heads) + + # Create document IDs using the helper from mask_definitions + doc_ids = random_doc_id_tensor(max_heads, max_batch, max_seqlen, device=device) + doc_ids_cute = from_dlpack(doc_ids.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 + ) + + # Generate base configurations + base_configs = generate_configs(batch_sizes, num_heads, seqlens, mask_names) + + # Update configs with aux tensors for document masking + configs = [] + for config in base_configs: + if config.mask_name == "document": + # Add aux tensors for document masking + configs.append( + BenchmarkConfig( + batch_size=config.batch_size, + num_heads=config.num_heads, + seqlen_q=config.seqlen_q, + seqlen_k=config.seqlen_k, + mask_name=config.mask_name, + tile_m=config.tile_m, + tile_n=config.tile_n, + use_fast_sampling=False, + aux_tensors_cute=[doc_ids_cute], + ) + ) + else: + configs.append(config) + + # Run benchmarks + results = [] + print(f"Running {len(configs)} benchmark configurations...") + for config in tqdm(configs, desc="Benchmarking"): + try: + # Get mask pair from mask_definitions + mask_kwargs = {} + if config.mask_name == "sliding_window": + mask_kwargs["window_size"] = 128 # Default window size + + cute_mask_fn, pytorch_mask_fn = get_mask_pair( + config.mask_name, + seqlen_q=config.seqlen_q, + seqlen_k=config.seqlen_k, + **mask_kwargs, + ) + + # For document masking, create wrapper that captures doc_ids + if config.mask_name == "document": + # PyTorch wrapper + def pytorch_mask_fn(b, h, q, kv): + return flex_document_mask(b, h, q, kv, doc_ids) + + # CuTe wrapper - reuse cute_document_mask with aux_tensors + cute_mask_fn = cute_document_mask + + result = run_benchmark(config, pytorch_mask_fn, cute_mask_fn) + results.append(result) + + except Exception as e: + print(f"Failed to run config {config}: {e}") + results.append( + BenchmarkResult( + config=config, + cute_time_ms=None, + pytorch_time_ms=None, + error_message=str(e), + ) + ) + finally: + torch.cuda.empty_cache() + torch._dynamo.reset() + + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/tests/cute/benchmark_mask_mod.py 
b/tests/cute/benchmark_mask_mod.py new file mode 100644 index 00000000000..ecf9ff4ea68 --- /dev/null +++ b/tests/cute/benchmark_mask_mod.py @@ -0,0 +1,686 @@ +""" +FlashAttention benchmarking script with Flex Attention-style +mask mod support and varlen sequences. +""" + +from dataclasses import dataclass +import math +from typing import Any, Dict, Optional, Tuple + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +import numpy as np +import torch + +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 +from mask_mod_definitions import ( + get_mask_pair, + random_doc_id_tensor, +) +from flash_attn.cute.block_sparsity import ( + compute_block_sparsity, + BlockSparseTensorsTorch, + to_cute_block_sparse_tensors, +) + + +@dataclass +class BenchmarkConfig: + """Benchmark configuration""" + + # Model parameters + headdim: int + headdim_v: int + nheads: int + nheads_kv: int + dtype: torch.dtype + + # Sequence parameters + batch_size: int = 2 + seqlen_q: int = 8192 + seqlen_k: int = 8192 + + # Varlen parameters + use_varlen: bool = False + min_seqlen_q: Optional[int] = None # If None, use seqlen_q // 2 + max_seqlen_q: Optional[int] = None # If None, use seqlen_q + min_seqlen_k: Optional[int] = None # If None, use seqlen_k // 2 + max_seqlen_k: Optional[int] = None # If None, use seqlen_k + + # Mask parameters + use_mask_mod: bool = True + mask_mod_name: str = "causal" + has_aux_tensors: bool = mask_mod_name == "document" + + # Sliding window parameter (used when mask_mod_name == "sliding_window") + window_size: int = 128 + + # Attention parameters + causal: bool = False + is_local: bool = False + window_left: Optional[int] = 128 # For base Flash Attention local + window_right: Optional[int] = 0 # For base Flash Attention local + softcap: Optional[float] = None + use_learnable_sink: bool = False + + # Kernel configuration + tile_m: int = 128 + tile_n: int = 128 + num_stages: int = 2 + num_threads: int = 384 + intra_wg_overlap: bool = True + mma_pv_is_rs: bool = True + + # Benchmark parameters + warmup_iters: int = 10 + benchmark_iters: int = 25 + verbose: bool = False + seed: int = 42 + + +class FlashAttentionBenchmark: + def __init__(self, config: BenchmarkConfig): + self.config = config + + torch.manual_seed(config.seed) + np.random.seed(config.seed) + + # Verify SM90 compute capability + compute_capability = torch.cuda.get_device_capability() + assert compute_capability >= (9, 0), ( + f"Requires SM90+, got SM{compute_capability[0]}{compute_capability[1]}" + ) + # causal overrides use_mask_mod + if config.causal: + config.use_mask_mod = False + + if config.use_mask_mod: + self.mask_mod_cute, self.mask_mod_flex = get_mask_pair( + config.mask_mod_name, + seqlen_q=config.seqlen_q, + seqlen_k=config.seqlen_k, + window_size=config.window_size, + ) + else: + self.mask_mod_cute = None + self.mask_mod_flex = None + + self._validate_config() + + def _validate_config(self): + config = self.config + + assert config.headdim <= 256, "headdim must be <= 256" + assert config.headdim_v <= 256, "headdim_v must be <= 256" + assert config.nheads % config.nheads_kv == 0, "nheads must be divisible by nheads_kv" + + alignment = 16 // config.dtype.itemsize + assert config.headdim % alignment == 0, f"headdim must be divisible by {alignment}" + assert config.headdim_v % alignment == 0, f"headdim_v must be divisible by {alignment}" + + # Validate is_local configuration + if config.is_local: + assert config.window_left is not None or 
config.window_right is not None, ( + "When is_local=True, at least one of window_left or window_right must be set" + ) + assert not config.use_mask_mod, ( + "Cannot use both is_local and use_mask_mod simultaneously" + ) + assert not config.causal, "Cannot use both is_local and causal simultaneously" + + # Validate mask_mod configuration + if config.use_mask_mod and config.mask_mod_name == "sliding_window": + assert config.window_size > 0, ( + "window_size must be positive when using sliding_window mask" + ) + + def _generate_varlen_seqlens(self, min_len: int, max_len: int) -> Tuple[torch.Tensor, int]: + """Generate random sequence lengths and compute cumulative lengths.""" + seqlens = torch.randint( + min_len, max_len + 1, (self.config.batch_size,), dtype=torch.int32, device="cuda" + ) + cu_seqlens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device="cuda"), + torch.cumsum(seqlens, dtype=torch.int32, dim=0), + ] + ) + + total_tokens = cu_seqlens[-1].item() + return cu_seqlens, total_tokens + + def _create_tensors(self) -> Dict[str, torch.Tensor]: + config = self.config + device = "cuda" + + if config.use_varlen: + # Set defaults for varlen range + min_q = config.min_seqlen_q if config.min_seqlen_q is not None else config.seqlen_q // 2 + max_q = config.max_seqlen_q if config.max_seqlen_q is not None else config.seqlen_q + min_k = config.min_seqlen_k if config.min_seqlen_k is not None else config.seqlen_k // 2 + max_k = config.max_seqlen_k if config.max_seqlen_k is not None else config.seqlen_k + + # Generate cu_seqlens + cu_seqlens_q, total_q = self._generate_varlen_seqlens(min_q, max_q) + cu_seqlens_k, total_k = self._generate_varlen_seqlens(min_k, max_k) + + # Varlen shape: (total_tokens, nheads, headdim) + q = torch.randn( + total_q, config.nheads, config.headdim, dtype=config.dtype, device=device + ) + k = torch.randn( + total_k, config.nheads_kv, config.headdim, dtype=config.dtype, device=device + ) + v = torch.randn( + total_k, config.nheads_kv, config.headdim_v, dtype=config.dtype, device=device + ) + out = torch.empty( + total_q, config.nheads, config.headdim_v, dtype=config.dtype, device=device + ) + lse = torch.empty(config.nheads, total_q, dtype=torch.float32, device=device) + + tensors = { + "q": q.contiguous(), + "k": k.contiguous(), + "v": v.contiguous(), + "out": out.contiguous(), + "lse": lse.contiguous(), + "cu_seqlens_q": cu_seqlens_q.contiguous(), + "cu_seqlens_k": cu_seqlens_k.contiguous(), + } + + if config.verbose: + print(f"Varlen: total_q={total_q}, total_k={total_k}") + print(f"Q seqlens: {cu_seqlens_q[1:] - cu_seqlens_q[:-1]}") + print(f"K seqlens: {cu_seqlens_k[1:] - cu_seqlens_k[:-1]}") + else: + # Standard shape: (batch, seqlen, nheads, headdim) + q = torch.randn( + config.batch_size, + config.seqlen_q, + config.nheads, + config.headdim, + dtype=config.dtype, + device=device, + ) + k = torch.randn( + config.batch_size, + config.seqlen_k, + config.nheads_kv, + config.headdim, + dtype=config.dtype, + device=device, + ) + v = torch.randn( + config.batch_size, + config.seqlen_k, + config.nheads_kv, + config.headdim_v, + dtype=config.dtype, + device=device, + ) + out = torch.empty( + config.batch_size, + config.seqlen_q, + config.nheads, + config.headdim_v, + dtype=config.dtype, + device=device, + ) + lse = torch.empty( + config.batch_size, + config.nheads, + config.seqlen_q, + dtype=torch.float32, + device=device, + ) + + tensors = { + "q": q.contiguous(), + "k": k.contiguous(), + "v": v.contiguous(), + "out": out.contiguous(), + "lse": lse.contiguous(), + } 
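+        # Optional extras appended below: a per-head learnable sink tensor and,
+        # when a mask_mod is used, precomputed block-sparsity metadata (plus any
+        # aux tensors such as document ids).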
+ + if config.use_learnable_sink: + learnable_sink = torch.rand(config.nheads, dtype=torch.bfloat16, device=device) + + tensors["learnable_sink"] = learnable_sink.contiguous() + + # Compute block sparsity when using mask_mod + if config.use_mask_mod: + if config.mask_mod_name == "document": + doc_id = random_doc_id_tensor( + config.batch_size, config.nheads, config.seqlen_q, device=device + ) + tensors["aux_tensors"] = [doc_id.contiguous()] + full_cnt, full_idx, mask_cnt, mask_idx = compute_block_sparsity( + config=self.config, + mask_mod_flex=self.mask_mod_flex, + device=device, + cu_seqlens_q=tensors.get("cu_seqlens_q"), + cu_seqlens_k=tensors.get("cu_seqlens_k"), + aux_tensors=tensors.get("aux_tensors"), + ) + + if all(t is not None for t in [full_cnt, full_idx, mask_cnt, mask_idx]): + tensors["block_sparse_tensors"] = BlockSparseTensorsTorch( + mask_block_cnt=mask_cnt.contiguous(), + mask_block_idx=mask_idx.contiguous(), + full_block_cnt=full_cnt.contiguous(), + full_block_idx=full_idx.contiguous(), + ) + + if config.verbose: + total_full = full_cnt.sum().item() + total_partial = mask_cnt.sum().item() + + if config.use_varlen: + # Compute max possible blocks across all sequences + max_blocks = 0 + for i in range(config.batch_size): + seq_len_q = ( + tensors["cu_seqlens_q"][i + 1] - tensors["cu_seqlens_q"][i] + ).item() + seq_len_k = ( + tensors["cu_seqlens_k"][i + 1] - tensors["cu_seqlens_k"][i] + ).item() + n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m + n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n + max_blocks += n_blocks_q * n_blocks_k * config.nheads + else: + n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n + n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m + max_blocks = n_blocks_k * n_blocks_q * config.nheads * config.batch_size + + skipped = max_blocks - total_full - total_partial + print( + f"Block stats: Full={total_full}, Partial={total_partial}, " + f"Skipped={skipped}/{max_blocks}" + ) + + return tensors + + def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple]: + config = self.config + + dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, + } + cute_dtype = dtype_map[config.dtype] + + qhead_per_kvhead = config.nheads // config.nheads_kv + kernel = FlashAttentionForwardSm90( + cute_dtype, + config.headdim, + config.headdim_v, + qhead_per_kvhead, + is_causal=config.causal, + is_local=config.is_local, + pack_gqa=False, + tile_m=config.tile_m, + tile_n=config.tile_n, + num_stages=config.num_stages, + num_threads=config.num_threads, + intra_wg_overlap=config.intra_wg_overlap, + mma_pv_is_rs=config.mma_pv_is_rs, + mask_mod=self.mask_mod_cute, + Q_in_regs=False, + has_aux_tensors=config.has_aux_tensors, + ) + + softmax_scale = 1.0 / math.sqrt(config.headdim) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Convert tensors to cute + q_cute = from_dlpack(tensors["q"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["q"].ndim - 1 + ) + k_cute = from_dlpack(tensors["k"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["k"].ndim - 1 + ) + v_cute = from_dlpack(tensors["v"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["v"].ndim - 1 + ) + out_cute = from_dlpack(tensors["out"].detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=tensors["out"].ndim - 1 + ) + lse_cute = from_dlpack(tensors["lse"].detach(), 
assumed_align=4).mark_layout_dynamic( + leading_dim=tensors["lse"].ndim - 1 + ) + + # Varlen tensors + cu_seqlens_q_cute = ( + from_dlpack(tensors["cu_seqlens_q"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + if "cu_seqlens_q" in tensors + else None + ) + cu_seqlens_k_cute = ( + from_dlpack(tensors["cu_seqlens_k"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + if "cu_seqlens_k" in tensors + else None + ) + learnable_sink_cute = ( + from_dlpack(tensors["learnable_sink"].detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + if "learnable_sink" in tensors + else None + ) + + blocksparse_tensors_cute = ( + to_cute_block_sparse_tensors(tensors["block_sparse_tensors"]) + if "block_sparse_tensors" in tensors + else None + ) + + if "aux_tensors" in tensors: + aux_tensors_cute = [] + for i in range(len(tensors["aux_tensors"])): + buf = from_dlpack(tensors["aux_tensors"][i].detach(), assumed_align=4) + aux_tensors_cute.append(buf.mark_layout_dynamic(leading_dim=2)) + + else: + aux_tensors_cute = None + + # Window parameters for is_local + window_left_cute = ( + cutlass.Int32(config.window_left) if config.window_left is not None else None + ) + window_right_cute = ( + cutlass.Int32(config.window_right) if config.window_right is not None else None + ) + + compiled = cute.compile( + kernel, + q_cute, + k_cute, + v_cute, + out_cute, + lse_cute, + softmax_scale, + current_stream, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + None, # seqused_q + None, # seqused_k + None, # page_table + window_left_cute, + window_right_cute, + learnable_sink_cute, + blocksparse_tensors_cute, + aux_tensors_cute, + # None, + ) + + args = ( + q_cute, + k_cute, + v_cute, + out_cute, + lse_cute, + softmax_scale, + current_stream, + cu_seqlens_q_cute, + cu_seqlens_k_cute, + None, + None, + None, + window_left_cute, + window_right_cute, + learnable_sink_cute, + blocksparse_tensors_cute, + aux_tensors_cute, + # None, + ) + + return compiled, args + + def _calculate_flops(self, tensors: Dict[str, torch.Tensor]) -> float: + config = self.config + + # Estimate sparsity for known mask patterns + if config.is_local: + # Local attention with window_left and window_right + window_left = config.window_left if config.window_left is not None else 0 + window_right = config.window_right if config.window_right is not None else 0 + total_window = window_left + window_right + 1 # +1 for current position + sparsity_ratio = min(1.0, total_window / config.seqlen_k) + elif config.use_mask_mod: + if config.mask_mod_name in ["identity", "identity_partial"]: + sparsity_ratio = 1.0 + elif config.mask_mod_name in ["causal", "block_causal"]: + sparsity_ratio = 0.5 + elif config.mask_mod_name == "sliding_window": + # Use configured window size + sparsity_ratio = min(1.0, config.window_size / config.seqlen_k) + elif config.mask_mod_name == "block_diagonal": + block_size = 64 + num_blocks = (config.seqlen_k + block_size - 1) // block_size + sparsity_ratio = 1.0 / num_blocks if num_blocks > 1 else 1.0 + elif config.mask_mod_name == "document": + vals = tensors["aux_tensors"][0] + val_mask = torch.ones_like(vals, dtype=torch.bool) + val_mask[..., 1:] = vals[..., 1:] != vals[..., :-1] + total = torch.where(val_mask, vals.square(), 0).sum() + sparsity_ratio = total / (config.seqlen_q * config.seqlen_k) + else: + sparsity_ratio = 1.0 + elif config.causal: + sparsity_ratio = 0.5 + else: + sparsity_ratio = 1.0 + + if config.use_varlen: + # Compute FLOPs per sequence and sum + total_flops = 0 + cu_q = 
tensors["cu_seqlens_q"] + cu_k = tensors["cu_seqlens_k"] + for i in range(config.batch_size): + seq_len_q = (cu_q[i + 1] - cu_q[i]).item() + seq_len_k = (cu_k[i + 1] - cu_k[i]).item() + + # Adjust sparsity for local attention in varlen case + if config.is_local: + window_left = config.window_left if config.window_left is not None else 0 + window_right = config.window_right if config.window_right is not None else 0 + total_window = window_left + window_right + 1 + seq_sparsity = min(1.0, total_window / seq_len_k) + elif config.use_mask_mod and config.mask_mod_name == "sliding_window": + seq_sparsity = min(1.0, config.window_size / seq_len_k) + else: + seq_sparsity = sparsity_ratio + + num_cells = int(seq_len_q * seq_len_k * seq_sparsity) + + if config.headdim == config.headdim_v: + flops_this_seq = 4 * config.nheads * num_cells * config.headdim + else: + flops_this_seq = ( + 2 * config.nheads * num_cells * config.headdim + + 2 * config.nheads * num_cells * config.headdim_v + ) + total_flops += flops_this_seq + return total_flops + else: + num_cells = int(config.seqlen_q * config.seqlen_k * sparsity_ratio) + if config.headdim == config.headdim_v: + flops_per_batch = 4 * config.nheads * num_cells * config.headdim + else: + flops_per_batch = ( + 2 * config.nheads * num_cells * config.headdim + + 2 * config.nheads * num_cells * config.headdim_v + ) + return flops_per_batch * config.batch_size + + def benchmark(self) -> Dict[str, Any]: + config = self.config + + tensors = self._create_tensors() + compiled_kernel, args = self._compile_kernel(tensors) + + # Warmup + for _ in range(config.warmup_iters): + compiled_kernel(*args) + torch.cuda.synchronize() + + # Benchmark + times = [] + for _ in range(config.benchmark_iters): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + compiled_kernel(*args) + end.record() + torch.cuda.synchronize() + + times.append(start.elapsed_time(end)) + + times_tensor = torch.tensor(times) + mean_time = times_tensor.mean().item() + std_time = times_tensor.std().item() if len(times) > 1 else 0.0 + + total_flops = self._calculate_flops(tensors) + tflops = total_flops / (mean_time * 1e-3) / 1e12 + + # Bandwidth calculation + bytes_per_element = config.dtype.itemsize + if config.use_varlen: + total_q = tensors["q"].shape[0] + total_k = tensors["k"].shape[0] + memory_accessed = ( + total_q * config.nheads * config.headdim * bytes_per_element + + total_k * config.nheads_kv * config.headdim * bytes_per_element + + total_k * config.nheads_kv * config.headdim_v * bytes_per_element + + total_q * config.nheads * config.headdim_v * bytes_per_element + ) + else: + memory_accessed = ( + config.batch_size + * config.seqlen_q + * config.nheads + * config.headdim + * bytes_per_element + + config.batch_size + * config.seqlen_k + * config.nheads_kv + * config.headdim + * bytes_per_element + + config.batch_size + * config.seqlen_k + * config.nheads_kv + * config.headdim_v + * bytes_per_element + + config.batch_size + * config.seqlen_q + * config.nheads + * config.headdim_v + * bytes_per_element + ) + bandwidth_gbps = memory_accessed / (mean_time * 1e-3) / 1e9 + + results = { + "mean_time_ms": mean_time, + "std_time_ms": std_time, + "tflops": tflops, + "bandwidth_gbps": bandwidth_gbps, + } + + if config.verbose: + self._print_results(results) + + return results + + def _print_results(self, results: Dict[str, Any]): + config = self.config + + # Basic configuration + if config.use_varlen: + print( + f"Shape: B={config.batch_size} 
(varlen), HD={config.headdim}, " + f"NH={config.nheads}, NKV={config.nheads_kv}" + ) + else: + print( + f"Shape: B={config.batch_size}, Q={config.seqlen_q}, K={config.seqlen_k}, " + f"HD={config.headdim}, NH={config.nheads}, NKV={config.nheads_kv}" + ) + + # Attention pattern + attn_info = [] + if config.causal: + attn_info.append("causal") + if config.is_local: + window_info = f"local(L={config.window_left},R={config.window_right})" + attn_info.append(window_info) + if config.use_mask_mod: + if config.mask_mod_name == "sliding_window": + attn_info.append(f"mask_mod={config.mask_mod_name}(w={config.window_size})") + else: + attn_info.append(f"mask_mod={config.mask_mod_name}") + if config.use_varlen: + attn_info.append("varlen") + if attn_info: + print(f"Attention: {', '.join(attn_info)}") + + # Performance metrics + print(f"Time: {results['mean_time_ms']:.3f} ± {results['std_time_ms']:.3f} ms") + print(f"Throughput: {results['tflops']:.2f} TFLOPS") + print(f"Bandwidth: {results['bandwidth_gbps']:.1f} GB/s") + + +if __name__ == "__main__": + B = 2 + config = BenchmarkConfig( + headdim=128, + headdim_v=128, + nheads=16, + nheads_kv=16, + dtype=torch.bfloat16, + batch_size=B, + # batch_size=1, + seqlen_q=8192, + # seqlen_q=128, + seqlen_k=8192, + # seqlen_k=192, + use_varlen=False, + use_mask_mod=False, + mask_mod_name="causal", + window_size=128, # Configurable window size for mask_mod + use_learnable_sink=False, + causal=True, + is_local=False, + verbose=True, + ) + + # Example 2: Base Flash Attention Local + # config = BenchmarkConfig( + # headdim=64, + # headdim_v=64, + # nheads=64, + # nheads_kv=8, + # dtype=torch.bfloat16, + # batch_size=2, + # seqlen_q=8192, + # seqlen_k=8192, + # use_varlen=False, + # use_mask_mod=False, + # causal=False, + # is_local=True, + # window_left=128, # Left window size for base local attention + # window_right=0, # Right window size for base local attention + # verbose=True, + # ) + + benchmark = FlashAttentionBenchmark(config) + results = benchmark.benchmark() diff --git a/tests/cute/mask_mod_definitions.py b/tests/cute/mask_mod_definitions.py new file mode 100644 index 00000000000..0820c6f5271 --- /dev/null +++ b/tests/cute/mask_mod_definitions.py @@ -0,0 +1,332 @@ +from typing import Callable, Optional + +import random +import math + +import cutlass +import cutlass.cute as cute +import torch + +from flash_attn.cute import utils +from flash_attn.cute.block_sparsity import fast_sampling + + +# ============================================================================= +# CuTe mask_mod functions (for kernel compilation) +# All use signature: (batch, head, m_idx, n_idx, seqlen_info, aux_tensors) +# ============================================================================= + +# ============================================================================= +# mask_mod functions that don't use global indices +# ============================================================================= + + +@fast_sampling +@cute.jit +def cute_causal_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors: None, +) -> cute.TensorSSA: + offset = seqlen_info.seqlen_k - seqlen_info.seqlen_q + offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32) + return n_idx <= (m_idx + offset_ssa) + + +def get_cute_causal_mask(offset: int): + return cute_causal_mask + + +def get_cute_block_causal_mask(offset: int): + @fast_sampling + @cute.jit + def _cute_block_causal_mask( + batch: cute.TensorSSA, + head: 
cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors: None, + ) -> cute.TensorSSA: + offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32) + return n_idx <= (m_idx + offset_ssa) + + return _cute_block_causal_mask + + +def get_cute_sliding_window_mask(window_left: int, window_right: int, offset: int): + @fast_sampling + @cute.jit + def _cute_sliding_window_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, + ) -> cute.TensorSSA: + offset = seqlen_info.seqlen_k - seqlen_info.seqlen_q + offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32) + window_left_ssa = utils.scalar_to_ssa(window_left, cutlass.Int32) + window_right_ssa = utils.scalar_to_ssa(window_right, cutlass.Int32) + center = m_idx + offset_ssa + lower = center - window_left_ssa + upper = center + window_right_ssa + return (n_idx >= lower) & (n_idx <= upper) + + return _cute_sliding_window_mask + + +@fast_sampling +@cute.jit +def cute_block_diagonal_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + block_size_ssa = utils.scalar_to_ssa(128, cutlass.Int32) + return (m_idx // block_size_ssa) == (n_idx // block_size_ssa) + + +@cute.jit +def cute_mini_causal_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + tile_size_ssa = utils.scalar_to_ssa(128, cutlass.Int32) + m_mod = m_idx % tile_size_ssa + n_mod = n_idx % tile_size_ssa + return m_mod >= n_mod + + +@fast_sampling +@cute.jit +def cute_prefix_lm_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + """Prefix LM mask: first 512 tokens attend bidirectionally, rest use causal masking.""" + prefix_size_ssa = utils.scalar_to_ssa(512, cutlass.Int32) + both_in_prefix = (m_idx < prefix_size_ssa) & (n_idx < prefix_size_ssa) + causal_part = m_idx >= n_idx + return both_in_prefix | causal_part + + +@cute.jit +def cute_dilated_sliding_window_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + """Dilated sliding window: every other position in a 256-position window.""" + window_size_ssa = utils.scalar_to_ssa(256, cutlass.Int32) + dilation_ssa = utils.scalar_to_ssa(2, cutlass.Int32) + in_window = (m_idx >= n_idx) & (m_idx - n_idx < window_size_ssa) + dilated = ((m_idx - n_idx) % dilation_ssa) == utils.scalar_to_ssa(0, cutlass.Int32) + return in_window & dilated + + +@fast_sampling +@cute.jit +def cute_document_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors: list, +) -> cute.TensorSSA: + doc_id = aux_tensors[0] + m_doc = utils.scalar_to_ssa(doc_id[batch[0], head[0], m_idx[0]], cutlass.Int32) + n_doc = utils.scalar_to_ssa(doc_id[batch[0], head[0], n_idx[0]], cutlass.Int32) + return m_doc == n_doc + + +@fast_sampling +@cute.jit +def cute_ima_mask( + batch: cute.TensorSSA, + head: cute.TensorSSA, + m_idx: cute.TensorSSA, + n_idx: cute.TensorSSA, + seqlen_info, + aux_tensors, +) -> cute.TensorSSA: + bias = aux_tensors[0] + threshold = utils.scalar_to_ssa(bias[n_idx[0]], cutlass.Int32) + return n_idx >= threshold + + +# 
============================================================================= +# mask_mod functions that use global indices (for use with variable sequence length) +# Global indices computed as: m_idx_global = m_idx + seqlen_info.offset_q +# n_idx_global = n_idx + seqlen_info.offset_k +# ============================================================================= + +# TODO: Add varlen mask implementations here + + +# ============================================================================= +# Eager reference functions (PyTorch/Flex Attention signatures) +# ============================================================================= + + +def get_flex_causal_mask(offset: int): + def _flex_causal_mask(b, h, q_idx, kv_idx): + return kv_idx <= q_idx + offset + + return _flex_causal_mask + + +def get_flex_block_causal_mask(offset: int): + def _flex_block_causal_mask(b, h, q_idx, kv_idx): + return kv_idx <= q_idx + offset + + return _flex_block_causal_mask + + +def get_flex_sliding_window_mask(window_left: int, window_right: int, offset: int): + def _flex_sliding_window_mask(b, h, q_idx, kv_idx): + center = q_idx + offset + lower = center - window_left + upper = center + window_right + return (kv_idx >= lower) & (kv_idx <= upper) + + return _flex_sliding_window_mask + + +def flex_block_diagonal_mask(b, h, q_idx, kv_idx): + block_size = 128 + return (q_idx // block_size) == (kv_idx // block_size) + + +def flex_mini_causal_mask(b, h, q_idx, kv_idx): + return (q_idx % 128) >= (kv_idx % 128) + + +def flex_prefix_lm_mask(b, h, q_idx, kv_idx): + """Prefix LM mask: first 512 tokens attend bidirectionally, rest use causal masking.""" + prefix_size = 512 + both_in_prefix = (q_idx < prefix_size) & (kv_idx < prefix_size) + causal_part = q_idx >= kv_idx + return both_in_prefix | causal_part + + +def flex_dilated_sliding_window_mask(b, h, q_idx, kv_idx): + """Dilated sliding window: every other position in a 256-position window.""" + window_size = 256 + dilation = 2 + in_window = (q_idx >= kv_idx) & (q_idx - kv_idx < window_size) + dilated = ((q_idx - kv_idx) % dilation) == 0 + return in_window & dilated + + +def flex_document_mask(b, h, q_idx, kv_idx, doc_id): + return doc_id[b, h, q_idx] == doc_id[b, h, kv_idx] + + +def flex_ima_mask(b, h, q_idx, kv_idx, bias): + return kv_idx >= bias[kv_idx] + + +# ============================================================================= +# Utility functions +# ============================================================================= + + +def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): + """Generate synthetic document ids shared across heads.""" + doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device) + for b in range(batch): + N = seqlen_q + max_segments = max(1, math.ceil(math.sqrt(max(N // 4, 1)))) + n = random.randint(1, max_segments) + n = min(n, N) + cuts = sorted(random.sample(range(1, N), n - 1)) + lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))] + base_doc_ids = torch.repeat_interleave( + torch.arange(len(lengths), device=device, dtype=torch.int32), + torch.tensor(lengths, device=device, dtype=torch.int32), + ) + + for h in range(nheads): + doc_ids_tensor[b, h, :] = base_doc_ids + return doc_ids_tensor + + +# ============================================================================= +# Mask registry and factory functions +# ============================================================================= + + +STATIC_MASKS = { + "block_diagonal": (cute_block_diagonal_mask, 
flex_block_diagonal_mask), + "mini_causal": (cute_mini_causal_mask, flex_mini_causal_mask), + "prefix_lm": (cute_prefix_lm_mask, flex_prefix_lm_mask), + "dilated_sliding_window": ( + cute_dilated_sliding_window_mask, + flex_dilated_sliding_window_mask, + ), + "document": (cute_document_mask, flex_document_mask), + "ima": (cute_ima_mask, flex_ima_mask), +} + +PARAMETERIZED_MASK_FACTORIES = { + "causal": (get_cute_causal_mask, get_flex_causal_mask), + "block_causal": (get_cute_block_causal_mask, get_flex_block_causal_mask), + "sliding_window": (get_cute_sliding_window_mask, get_flex_sliding_window_mask), +} + + +def get_mask_pair(mask_name, seqlen_q=None, seqlen_k=None, window_size=None): + """Get (cute_mask, flex_mask) pair for the given mask name. + + For static masks, seqlen info is not needed. + For parameterized masks, seqlen_q and seqlen_k are required. + """ + if mask_name in STATIC_MASKS: + return STATIC_MASKS[mask_name] + + if mask_name not in PARAMETERIZED_MASK_FACTORIES: + raise ValueError(f"Unknown mask: {mask_name}") + + if seqlen_q is None or seqlen_k is None: + raise ValueError( + f"Parameterized mask '{mask_name}' requires seqlen_q and seqlen_k" + ) + + cute_factory, flex_factory = PARAMETERIZED_MASK_FACTORIES[mask_name] + offset = seqlen_k - seqlen_q + + if mask_name == "sliding_window": + if window_size is None: + raise ValueError("sliding_window mask requires window_size parameter") + cute_mask = cute_factory(window_size, window_size, offset) + flex_mask = flex_factory(window_size, window_size, offset) + else: + cute_mask = cute_factory(offset) + flex_mask = flex_factory(offset) + + return cute_mask, flex_mask + + +if __name__ == "__main__": + doc_ids = random_doc_id_tensor(1, 2, 128) + print(f"{doc_ids = }") diff --git a/tests/cute/score_mod_definitions.py b/tests/cute/score_mod_definitions.py new file mode 100644 index 00000000000..be6333a6448 --- /dev/null +++ b/tests/cute/score_mod_definitions.py @@ -0,0 +1,591 @@ +import torch +import cutlass +import cutlass.cute as cute +from cutlass._mlir.dialects import math as mlir_math +import operator + +# ============================================================================= +# Score_mod functions that don't use global indices +# All use signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors) +# ============================================================================= + + +@cute.jit +def score_mod_identity(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + return tSrS_ssa + + +@cute.jit +def score_mod_causal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + mask = operator.ge(q_idx, kv_idx) + return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) + + +@cute.jit +def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + diff = q_idx - kv_idx + abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype) + return tSrS_ssa + abs_diff.to(cutlass.Float32) + + +@cute.jit +def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + diff = q_idx - kv_idx + abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype) + scaled = abs_diff * cute.full_like(abs_diff, 2) + return tSrS_ssa + scaled.to(cutlass.Float32) + + +@cute.jit +def score_mod_times_two(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + return tSrS_ssa * cute.full_like(tSrS_ssa, 2) + + +@cute.jit +def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): 
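+    """ALiBi-style bias: score - slope * |q_idx - kv_idx|, where the per-head
+    slope is effectively 2**-(h_idx + 1) (the constant 0.125 * ln(2) * log2(e)
+    folds to 0.125, so slope = exp2(-8 * (h_idx + 1) * 0.125))."""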
+ score = tSrS_ssa.to(cutlass.Float32) + slope_exp = (h_idx + cute.full_like(h_idx, 1)) * cute.full_like(h_idx, -8) + slope = cute.math.exp2( + slope_exp.to(cutlass.Float32) + * cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634) + ) + diff = q_idx - kv_idx + abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype).to(cutlass.Float32) + return score - slope * abs_diff + + +@cute.jit +def score_mod_sliding_window(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + diff = q_idx - kv_idx + abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype) + mask = operator.le(abs_diff, cute.full_like(abs_diff, 256)) + return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) + + +@cute.jit +def score_mod_block_diagonal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + q_block = q_idx // 64 + kv_block = kv_idx // 64 + mask = operator.eq(q_block, kv_block) + return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) + + +@cute.jit +def score_mod_causal_v2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + diff = q_idx - kv_idx + mask = operator.ge(diff, cute.full_like(diff, 0)) + return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) + + +@cute.jit +def score_mod_batch_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + batch_bias = aux_tensors[0] + dtype = batch_bias.element_type + b_frag = cute.make_fragment(1, cutlass.Int32) + b_frag.store(b_idx) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = batch_bias[b_frag[0]] + bias_val = (bias_frag.load()).to(cutlass.Float32) + return tSrS_ssa + bias_val + + +@cute.jit +def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + head_bias = aux_tensors[0] + pos_bias = aux_tensors[1] + dtype = head_bias.element_type + + h_frag = cute.make_fragment(1, cutlass.Int32) + h_frag.store(h_idx) + head_val_frag = cute.make_fragment(1, dtype) + head_val_frag[0] = head_bias[h_frag[0]] + head_val = (head_val_frag.load()).to(cutlass.Float32) + + q_frag = cute.make_fragment(1, cutlass.Int32) + q_frag.store(q_idx) + pos_val_frag = cute.make_fragment(1, dtype) + pos_val_frag[0] = pos_bias[q_frag[0]] + pos_val = (pos_val_frag.load()).to(cutlass.Float32) + + return tSrS_ssa + head_val + pos_val + + +# ============================================================================= +# Score_mod functions that use global indices +# All use signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors) +# Global indices computed as: q_idx_global = q_idx + seqlen_info.offset_q (and similarly for kv) +# ============================================================================= + + +@cute.jit +def score_mod_global_kv_bias( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Per-token bias using global kv index.""" + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + token_bias = aux_tensors[0] + dtype = token_bias.element_type + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[kv_frag[0]] + + return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32) + + +@cute.jit +def score_mod_global_q_bias( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Per-token bias using global q index.""" + offset_q = seqlen_info.offset_q + q_idx_global = q_idx + offset_q + token_bias = aux_tensors[0] + dtype = 
token_bias.element_type + q_frag = cute.make_fragment(1, cutlass.Int32) + q_frag.store(q_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[q_frag[0]] + return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32) + + +@cute.jit +def score_mod_global_rel_plus_kv_bias( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Relative position (logical) + per-token bias (global kv).""" + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + token_bias = aux_tensors[0] + dtype = token_bias.element_type + + rel_pos = q_idx - kv_idx + rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype) + rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.1) + + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[kv_frag[0]] + + return tSrS_ssa + rel_bias + (bias_frag.load()).to(cutlass.Float32) + + +@cute.jit +def score_mod_global_q_and_kv_bias( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Both q and kv global indices.""" + offset_q = seqlen_info.offset_q + q_idx_global = q_idx + offset_q + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + q_bias = aux_tensors[0] + kv_bias = aux_tensors[1] + dtype = q_bias.element_type + + q_frag = cute.make_fragment(1, cutlass.Int32) + q_frag.store(q_idx_global) + q_bias_frag = cute.make_fragment(1, dtype) + q_bias_frag[0] = q_bias[q_frag[0]] + + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + kv_bias_frag = cute.make_fragment(1, dtype) + kv_bias_frag[0] = kv_bias[kv_frag[0]] + + return ( + tSrS_ssa + + (q_bias_frag.load()).to(cutlass.Float32) + + (kv_bias_frag.load()).to(cutlass.Float32) + ) + + +@cute.jit +def score_mod_global_logical_rel_plus_kv_bias( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Logical relative + global-indexed per-token bias.""" + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + token_bias = aux_tensors[0] + dtype = token_bias.element_type + + rel_pos = q_idx - kv_idx + rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype) + rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.01) + + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[kv_frag[0]] + + return tSrS_ssa + rel_bias + (bias_frag.load()).to(cutlass.Float32) + + +# "Stress tests" - score_mods with complex global index usage + +@cute.jit +def score_mod_stress_complex_arithmetic( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """All indices in complex arithmetic.""" + offset_q = seqlen_info.offset_q + q_idx_global = q_idx + offset_q + bias = aux_tensors[0] + dtype = bias.element_type + + # Use absolute value instead of squaring to avoid overflow with large sequences + rel_pos = q_idx - kv_idx + rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype) + rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001) + + q_frag = cute.make_fragment(1, cutlass.Int32) + q_frag.store(q_idx_global) + bias_q_frag = cute.make_fragment(1, dtype) + bias_q_frag[0] = bias[q_frag[0]] + bias_q = (bias_q_frag.load()).to(cutlass.Float32) + + scale = (b_idx + cute.full_like(b_idx, 1)) * (h_idx + cute.full_like(h_idx, 1)) + scale_f32 = scale.to(cutlass.Float32) * 0.001 
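+    # Combined: score + |q - kv| * 0.001 + bias[q_global] * (b + 1) * (h + 1) * 0.001,
+    # mirroring stress_complex_arithmetic_factory in the eager references below.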
+ + result = tSrS_ssa + rel_bias + bias_q * scale_f32 + return result + + +@cute.jit +def score_mod_stress_conditional_mask( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Conditional masking with global vs logical.""" + offset_q = seqlen_info.offset_q + q_idx_global = q_idx + offset_q + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + token_bias = aux_tensors[0] + dtype = token_bias.element_type + + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[kv_frag[0]] + bias_val = (bias_frag.load()).to(cutlass.Float32) + + is_causal = operator.ge(q_idx, kv_idx) + + global_diff = q_idx_global - kv_idx_global + is_nearby = operator.le( + cute.TensorSSA(mlir_math.absi(global_diff), global_diff.shape, global_diff.dtype), + cute.full_like(global_diff, 512), + ) + + both_conditions = is_causal & is_nearby + return cute.where(both_conditions, tSrS_ssa + bias_val, cute.full_like(tSrS_ssa, float("-inf"))) + + +@cute.jit +def score_mod_stress_multi_buffer( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Multiple aux tensors with different indexing.""" + offset_q = seqlen_info.offset_q + q_idx_global = q_idx + offset_q + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + batch_bias = aux_tensors[0] + head_scale = aux_tensors[1] + q_pos_bias = aux_tensors[2] + kv_pos_bias = aux_tensors[3] + rel_pos_scale = aux_tensors[4] + + dtype = batch_bias.element_type + + b_frag = cute.make_fragment(1, cutlass.Int32) + b_frag.store(b_idx) + bb_frag = cute.make_fragment(1, dtype) + bb_frag[0] = batch_bias[b_frag[0]] + bb_val = (bb_frag.load()).to(cutlass.Float32) + + h_frag = cute.make_fragment(1, cutlass.Int32) + h_frag.store(h_idx) + hs_frag = cute.make_fragment(1, dtype) + hs_frag[0] = head_scale[h_frag[0]] + hs_val = (hs_frag.load()).to(cutlass.Float32) + + qg_frag = cute.make_fragment(1, cutlass.Int32) + qg_frag.store(q_idx_global) + qpb_frag = cute.make_fragment(1, dtype) + qpb_frag[0] = q_pos_bias[qg_frag[0]] + qpb_val = (qpb_frag.load()).to(cutlass.Float32) + + kvg_frag = cute.make_fragment(1, cutlass.Int32) + kvg_frag.store(kv_idx_global) + kvpb_frag = cute.make_fragment(1, dtype) + kvpb_frag[0] = kv_pos_bias[kvg_frag[0]] + kvpb_val = (kvpb_frag.load()).to(cutlass.Float32) + + rel_idx = q_idx - kv_idx + cute.full_like(q_idx, 512) + rel_idx_clamped = cute.where( + operator.lt(rel_idx, cute.full_like(rel_idx, 0)), cute.full_like(rel_idx, 0), rel_idx + ) + rel_idx_clamped = cute.where( + operator.gt(rel_idx_clamped, cute.full_like(rel_idx_clamped, 1024)), + cute.full_like(rel_idx_clamped, 1024), + rel_idx_clamped, + ) + ri_frag = cute.make_fragment(1, cutlass.Int32) + ri_frag.store(rel_idx_clamped) + rps_frag = cute.make_fragment(1, dtype) + rps_frag[0] = rel_pos_scale[ri_frag[0]] + rps_val = (rps_frag.load()).to(cutlass.Float32) + + return tSrS_ssa * hs_val + bb_val + qpb_val + kvpb_val + rps_val * cute.full_like(tSrS_ssa, 0.1) + + +@cute.jit +def score_mod_stress_global_offset( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """Verify global - logical = offset.""" + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + token_bias = aux_tensors[0] + dtype = token_bias.element_type + + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[kv_frag[0]] + + return tSrS_ssa + 
(bias_frag.load()).to(cutlass.Float32) + + +@cute.jit +def score_mod_stress_xor_pattern( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + """XOR-based pattern using index bits.""" + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + token_bias = aux_tensors[0] + dtype = token_bias.element_type + + xor_logical = q_idx ^ kv_idx + pattern_logical = xor_logical & cute.full_like(xor_logical, 0xFF) + pattern_bias = pattern_logical.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001) + + kv_frag = cute.make_fragment(1, cutlass.Int32) + kv_frag.store(kv_idx_global) + bias_frag = cute.make_fragment(1, dtype) + bias_frag[0] = token_bias[kv_frag[0]] + + return ( + tSrS_ssa + + pattern_bias + + (bias_frag.load()).to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.1) + ) + + +@cute.jit +def score_mod_debug_global_idx( + tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors +): + # Don't read from aux_tensors at all - just add the global index as bias + offset_k = seqlen_info.offset_k + kv_idx_global = kv_idx + offset_k + bias = kv_idx_global.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001) + return tSrS_ssa + bias + + +# ============================================================================= +# Eager reference functions +# ============================================================================= + + +def identity_eager(score, b, h, q_idx, kv_idx): + return score + + +def causal_eager(score, b, h, q_idx, kv_idx): + return torch.where(q_idx >= kv_idx, score, float("-inf")) + + +def rel_bias_eager(score, b, h, q_idx, kv_idx): + return score + torch.abs(q_idx - kv_idx) + + +def rel_bias_x2_eager(score, b, h, q_idx, kv_idx): + return score + 2 * torch.abs(q_idx - kv_idx) + + +def times_two_eager(score, b, h, q_idx, kv_idx): + return score * 2 + + +def alibi_eager(score, b, h, q_idx, kv_idx): + slope = 2 ** (-8 * (h + 1) / 8) + return score - slope * torch.abs(q_idx - kv_idx) + + +def sliding_window_eager(score, b, h, q_idx, kv_idx): + return torch.where(torch.abs(q_idx - kv_idx) <= 256, score, float("-inf")) + + +def block_diagonal_eager(score, b, h, q_idx, kv_idx): + return torch.where(q_idx // 64 == kv_idx // 64, score, float("-inf")) + + +def causal_v2_eager(score, b, h, q_idx, kv_idx): + return torch.where(q_idx - kv_idx >= 0, score, float("-inf")) + + +def batch_bias_factory(bias_tensor): + def mod(score, b, h, q_idx, kv_idx): + return score + bias_tensor[b] + + return mod + + +def dual_buffer_factory(head_bias, pos_bias): + def mod(score, b, h, q_idx, kv_idx): + return score + head_bias[h] + pos_bias[q_idx] + + return mod + + +def packed_kv_bias_factory(bias_tensor, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + # Calculate valid length for this sequence + start = cu_seqlens_k[b] + seq_len = cu_seqlens_k[b+1] - start + + # Clamp kv_idx. 
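+        # kv_idx can run past this sequence's valid length when the reference is
+        # evaluated over the padded grid, so clamp it to keep the bias lookup in bounds.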
+ safe_kv_idx = torch.clamp(kv_idx, max=seq_len - 1) + + return score + bias_tensor[start + safe_kv_idx] + return mod + + +def packed_q_bias_factory(bias_tensor, cu_seqlens_q): + def mod(score, b, h, q_idx, kv_idx): + start = cu_seqlens_q[b] + seq_len = cu_seqlens_q[b+1] - start + + # Clamp q_idx + safe_q_idx = torch.clamp(q_idx, max=seq_len - 1) + + return score + bias_tensor[start + safe_q_idx] + return mod + + +def packed_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + start = cu_seqlens_k[b] + seq_len = cu_seqlens_k[b+1] - start + + # Clamp kv_idx + safe_kv_idx = torch.clamp(kv_idx, max=seq_len - 1) + + rel_bias = torch.abs(q_idx - kv_idx).float() * 0.1 + return score + rel_bias + bias_tensor[start + safe_kv_idx] + + return mod + + +def packed_q_and_kv_bias_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + # Handle Q bounds + q_start = cu_seqlens_q[b] + q_len = cu_seqlens_q[b+1] - q_start + safe_q_idx = torch.clamp(q_idx, max=q_len - 1) + + # Handle KV bounds + kv_start = cu_seqlens_k[b] + kv_len = cu_seqlens_k[b+1] - kv_start + safe_kv_idx = torch.clamp(kv_idx, max=kv_len - 1) + + return score + q_bias[q_start + safe_q_idx] + kv_bias[kv_start + safe_kv_idx] + + return mod + + +def packed_logical_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + rel_bias = torch.abs(q_idx - kv_idx).float() * 0.01 + return score + rel_bias + bias_tensor[cu_seqlens_k[b] + kv_idx] + + return mod + + +def stress_complex_arithmetic_factory(bias, cu_seqlens_q): + def mod(score, b, h, q_idx, kv_idx): + # Use absolute value instead of squaring to avoid overflow with large sequences + rel_pos_abs = torch.abs(q_idx - kv_idx) + q_global = cu_seqlens_q[b] + q_idx + bias_q = bias[q_global] + scale = (b + 1) * (h + 1) * 0.001 + rel_bias = rel_pos_abs * 0.001 + return score + rel_bias + bias_q * scale + + return mod + + +def stress_conditional_mask_factory(token_bias, cu_seqlens_q, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + kv_global = cu_seqlens_k[b] + kv_idx + bias_val = token_bias[kv_global] + is_causal = q_idx >= kv_idx + q_global = cu_seqlens_q[b] + q_idx + global_diff = q_global - kv_global + is_nearby = torch.abs(global_diff) <= 512 + both_conditions = is_causal & is_nearby + return torch.where(both_conditions, score + bias_val, float("-inf")) + + return mod + + +def stress_multi_buffer_factory( + batch_bias, + head_scale, + q_pos_bias, + kv_pos_bias, + rel_pos_scale, + cu_seqlens_q, + cu_seqlens_k, + max_rel_pos=512, +): + def mod(score, b, h, q_idx, kv_idx): + bb_val = batch_bias[b] + hs_val = head_scale[h] + qpb_val = q_pos_bias[cu_seqlens_q[b] + q_idx] + kvpb_val = kv_pos_bias[cu_seqlens_k[b] + kv_idx] + rel_idx = (q_idx - kv_idx + max_rel_pos).clamp(0, max_rel_pos * 2) + rps_val = rel_pos_scale[rel_idx] + return score * hs_val + bb_val + qpb_val + kvpb_val + rps_val * 0.1 + + return mod + + +def stress_global_offset_factory(token_bias, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + return score + token_bias[cu_seqlens_k[b] + kv_idx] + + return mod + + +def stress_xor_pattern_factory(token_bias, cu_seqlens_q, cu_seqlens_k): + def mod(score, b, h, q_idx, kv_idx): + xor_logical = q_idx ^ kv_idx + pattern_bias = (xor_logical & 0xFF).float() * 0.001 + kv_global = cu_seqlens_k[b] + kv_idx + return score + pattern_bias + token_bias[kv_global] * 0.1 + + return mod + +def debug_global_idx_factory(bias, cu_seqlens_k): + offsets = cu_seqlens_k.tolist() + def 
mod(score, b, h, q_idx, kv_idx): + global_kv = offsets[b] + kv_idx + return score + global_kv.float() * 0.001 + return mod diff --git a/tests/cute/test_block_sparsity.py b/tests/cute/test_block_sparsity.py new file mode 100644 index 00000000000..06af8d658c2 --- /dev/null +++ b/tests/cute/test_block_sparsity.py @@ -0,0 +1,485 @@ +"""Tests for block sparsity computation in flash attention.""" + +import pytest +import torch +from torch.nn.attention.flex_attention import create_block_mask + +from mask_mod_definitions import get_mask_pair +from flash_attn.cute.compute_block_sparsity import compute_block_sparsity + + +def _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + window_size=None, + aux_tensors=None, + use_fast_sampling=False, +): + """Call compute_block_sparsity and return torch tensors.""" + cute_mask, _ = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + _, torch_tensors = compute_block_sparsity( + tile_m=tile_m, + tile_n=tile_n, + batch_size=batch_size, + num_heads=nheads, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + mask_mod=cute_mask, + aux_tensors=aux_tensors, + device="cuda", + use_fast_sampling=use_fast_sampling, + ) + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = torch_tensors + return mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx + + +def _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, +): + """Compare block sparsity against reference, handling boundary block semantics. + + PyTorch treats OOB regions as masked, so boundary blocks with all in-bounds + elements unmasked appear as "partial" in PyTorch but "full" in CuTe. + + This applies to BOTH boundary m_blocks (OOB q_idx) and boundary n_blocks (OOB kv_idx). 
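+
+    For example, with seqlen_k = 100 and tile_n = 64 the last n_block spans kv 64..127
+    but only kv 64..99 are in bounds; if the mask keeps all of kv 64..99 for a row
+    block, CuTe reports the block as full while PyTorch reports it as partial because
+    kv 100..127 count as masked.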
+ """ + if not isinstance(mask_block_cnt, torch.Tensor): + return False, f"mask_block_cnt is not a tensor: {type(mask_block_cnt)}" + + n_blocks_q = mask_block_cnt.shape[2] + + # Identify boundary blocks + last_m_block = (seqlen_q - 1) // tile_m + last_n_block = (seqlen_k - 1) // tile_n + m_is_boundary = seqlen_q % tile_m != 0 + n_is_boundary = seqlen_k % tile_n != 0 + + def is_boundary_n_block(n_block): + return n_is_boundary and n_block == last_n_block + + def is_boundary_m_block(m_block): + return m_is_boundary and m_block == last_m_block + + for b in range(batch_size): + for h in range(nheads): + for m in range(n_blocks_q): + cute_mask_cnt = mask_block_cnt[b, h, m].item() + cute_full_cnt = full_block_cnt[b, h, m].item() + ref_mask_cnt = mask_block_cnt_ref[b, h, m].item() + ref_full_cnt = full_block_cnt_ref[b, h, m].item() + + cute_mask_set = set(mask_block_idx[b, h, m, :cute_mask_cnt].tolist()) + cute_full_set = set(full_block_idx[b, h, m, :cute_full_cnt].tolist()) + ref_mask_set = set(mask_block_idx_ref[b, h, m, :ref_mask_cnt].tolist()) + ref_full_set = set(full_block_idx_ref[b, h, m, :ref_full_cnt].tolist()) + + # A block is "boundary-affected" if EITHER the m_block OR n_block is at boundary + def is_boundary_affected(n_block): + return is_boundary_m_block(m) or is_boundary_n_block(n_block) + + # Blocks that are full in CuTe but not in ref + full_in_cute_not_ref = cute_full_set - ref_full_set + + for n_block in full_in_cute_not_ref: + if not is_boundary_affected(n_block): + return False, ( + f"Non-boundary block mismatch at [{b},{h},{m}]: " + f"n_block {n_block} is full in CuTe but not in ref" + ) + # Boundary-affected: CuTe says full, ref should say partial + if n_block not in ref_mask_set: + # Check if ref skipped it entirely (all masked) + # This is valid for boundary blocks + pass + + # Blocks that are partial in CuTe but full in ref (would be a bug) + partial_in_cute_full_in_ref = cute_mask_set & ref_full_set + if partial_in_cute_full_in_ref: + return False, ( + f"Block mismatch at [{b},{h},{m}]: " + f"n_blocks {sorted(partial_in_cute_full_in_ref)} are partial in CuTe but full in ref" + ) + + # Check non-boundary blocks match exactly + non_boundary_cute_full = { + n for n in cute_full_set if not is_boundary_affected(n) + } + non_boundary_ref_full = { + n for n in ref_full_set if not is_boundary_affected(n) + } + if non_boundary_cute_full != non_boundary_ref_full: + return False, ( + f"Non-boundary full block mismatch at [{b},{h},{m}]: " + f"CuTe={sorted(non_boundary_cute_full)}, ref={sorted(non_boundary_ref_full)}" + ) + + non_boundary_cute_mask = { + n for n in cute_mask_set if not is_boundary_affected(n) + } + non_boundary_ref_mask = { + n for n in ref_mask_set if not is_boundary_affected(n) + } + if non_boundary_cute_mask != non_boundary_ref_mask: + return False, ( + f"Non-boundary partial block mismatch at [{b},{h},{m}]: " + f"CuTe={sorted(non_boundary_cute_mask)}, ref={sorted(non_boundary_ref_mask)}" + ) + + return True, "" + + +# Test configurations +SEQLEN_PAIRS = [ + # Small aligned + (64, 64), + (128, 128), + (256, 256), + (512, 512), + # Rectangular + (128, 256), + (256, 128), + (512, 256), + (256, 512), + # Large aligned + (1024, 1024), + (2048, 2048), + (4096, 4096), + (8192, 8192), + # Large unaligned + (1000, 1000), + (2000, 2000), + (4000, 4000), + # Edge cases with unaligned seqlens + (113, 203), + (127, 127), + (129, 129), + (255, 255), + (257, 257), + (1023, 1023), + (1025, 1025), + (2047, 2047), + (2049, 2049), +] +TILE_SIZES = [ + # Standard powers of 2 + (32, 
32), + (64, 64), + (128, 128), + (256, 256), + # Rectangular + (32, 64), + (64, 32), + (64, 128), + (128, 64), + (128, 256), + (256, 128), + # Unusual sizes + (40, 40), + (48, 48), + (96, 96), + (112, 112), + (32, 128), + (128, 32), + (40, 96), + (96, 40), +] + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS) +@pytest.mark.parametrize("tile_m,tile_n", TILE_SIZES) +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize("mask_name", ["block_diagonal", "mini_causal"]) +def test_fixed_length_masks( + seqlen_q, seqlen_k, tile_m, tile_n, batch_size, nheads, mask_name +): + """Test fixed-length masks.""" + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + use_fast_sampling=False, + ) + ) + + _, mask_mod_flex = get_mask_pair(mask_name) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + print("CuTe results:") + print(f" mask_block_cnt: {mask_block_cnt}") + print(f" full_block_cnt: {full_block_cnt}") + print(f" mask_block_idx: {mask_block_idx}") + print(f" full_block_idx: {full_block_idx}") + print("Torch results:") + print(f" mask_block_cnt: {mask_block_cnt_ref}") + print(f" full_block_cnt: {full_block_cnt_ref}") + print(f" mask_block_idx: {mask_block_idx_ref}") + print(f" full_block_idx: {full_block_idx_ref}") + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + ) + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS) +@pytest.mark.parametrize( + "tile_m,tile_n", [(64, 64), (128, 128), (64, 128), (128, 64), (256, 256)] +) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize( + "mask_name,window_size", + [("causal", None), ("sliding_window", 64), ("sliding_window", 256)], +) +def test_parameterized_masks( + seqlen_q, seqlen_k, tile_m, tile_n, batch_size, nheads, mask_name, window_size +): + """Test parameterized masks.""" + if mask_name == "sliding_window" and seqlen_q > seqlen_k: + pytest.skip("Sliding window not supported for seqlen_q > seqlen_k") + + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + window_size=window_size, + ) + ) + + _, mask_mod_flex = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + 
mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + ) + + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_k,tile_m,tile_n", + [ + (1, 1, 64, 64), + (63, 63, 64, 64), + (65, 65, 64, 64), + (129, 129, 128, 128), + (100, 200, 64, 128), + ], +) +def test_edge_cases(seqlen_q, seqlen_k, tile_m, tile_n): + """Test edge cases with unaligned dimensions.""" + batch_size, nheads = 1, 1 + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + "causal", + ) + ) + + _, mask_mod_flex = get_mask_pair("causal", seqlen_q=seqlen_q, seqlen_k=seqlen_k) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + ) + assert all_match, f"Mismatch: {error_msg}" + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS) +@pytest.mark.parametrize( + "tile_m,tile_n", [(64, 64), (128, 128), (64, 128), (128, 64), (256, 256)] +) +@pytest.mark.parametrize("nheads", [1, 4]) +@pytest.mark.parametrize("mask_name", ["causal", "block_diagonal"]) +def test_fast_sampling(seqlen_q, seqlen_k, tile_m, tile_n, nheads, mask_name): + """Test fast sampling mode (5-point sampling).""" + batch_size = 1 + seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0) + + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = ( + _call_compute_block_sparsity( + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + mask_name, + use_fast_sampling=True, + ) + ) + + _, mask_mod_flex = get_mask_pair(mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k) + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _, + _, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + *_, + ) = block_mask.as_tuple() + + all_match, error_msg = _compare_block_sparsity( + mask_block_cnt, + mask_block_idx, + full_block_cnt, + full_block_idx, + mask_block_cnt_ref, + mask_block_idx_ref, + full_block_cnt_ref, + full_block_idx_ref, + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + ) + + assert all_match, f"Mismatch: {error_msg}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py new file mode 100644 index 00000000000..1c2088dd28a --- /dev/null +++ b/tests/cute/test_flash_attn.py @@ -0,0 +1,1522 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
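+#
+# Tests for the CuTe flash attention interface (flash_attn_func, flash_attn_varlen_func,
+# and the kvcache path). Set FLASH_ATTENTION_DISABLE_SPLIT=TRUE to skip the
+# num_splits > 1 (SplitKV) configurations exercised below.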
+ +import math +import itertools +import os + +import pytest +import torch + +from einops import rearrange, repeat + +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from flash_attn.cute.testing import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, + pad_input, + unpad_input, +) +from flash_attn.cute.interface import ( + flash_attn_func, + flash_attn_varlen_func, + flash_attn_combine, + _get_device_capability, +) + + +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" +# SplitKV and paged KV are not supported on SM90 +IS_SM90 = _get_device_capability() == 9 +TEST_BWD_ONLY = False +VERBOSE = True + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("softcap", [0.0, 15.0]) +@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128, 192]) +# @pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (3, 3), + (64, 32), + (64, 128), + (128, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + (4096, 4096), + (4224, 4224), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, + seqlen_k, + d, + causal, + local_enum, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, +): + local = local_enum > 0 + if local and causal: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + torch.cuda.empty_cache() + torch.cuda.synchronize() + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 2 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY: + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] + attention_chunk_vals = [0] + for dv, attention_chunk in 
itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = q_ref * softcap / 4 + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + if local_enum == 2: + window_size = (None, -window_size[1]) + elif local_enum == 3: + window_size = (-window_size[0], None) + if local: + print("window size = ", window_size) + # window_size = (-1, -1) if not local else (16, 0) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # k_extended = repeat(k_ref, "b s h d -> b s (h k) d", k=nheads // nheads_kv) + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float() + # # if qv is not None: + # # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # num_splits_vals = [1, 3] + pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False] + # SplitKV is not supported for hdim >= 192 + # pack_gqa_vals = [False] + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, 
num_splits_vals): + # SplitKV not supported on SM90 - skip this iteration + if IS_SM90 and num_splits > 1: + continue + out, lse = flash_attn_func( + q, + k, + v, + causal=causal, + # qv=qv, + # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + softcap=softcap, + learnable_sink=learnable_sink, + pack_gqa=pack_gqa, + num_splits=num_splits, + deterministic=deterministic, + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + if ( + dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and softcap == 0.0 + and dv == d + and learnable_sink is None + # and False + and not ((causal or local) and seqlen_k < seqlen_q) + ): + # TODO: SM90 backward pass has invalid MMA tile config for d=64 + non-causal + # The m_block_size=80 (non-causal) with head_dim=64 creates an invalid tile. + # Fix requires adjusting m_block_size or MMA config in flash_bwd_sm90.py + if IS_SM90 and d == 64 and not causal: + pytest.xfail("SM90 backward: d=64 + non-causal has invalid MMA tile config (m_block=80)") + # TODO: SM90 backward pass does not support local attention yet + if IS_SM90 and local: + pytest.xfail("SM90 backward: local attention not supported yet") + g = torch.randn_like(out) + # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + # breakpoint() + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + if VERBOSE: + diff_dq = (dq - dq_ref).abs() + max_idx = diff_dq.argmax() + coords = torch.unravel_index(max_idx, diff_dq.shape) + print(f"dQ max diff: {diff_dq.max().item()}") + print(f" at 
coordinates {tuple(c.item() for c in coords)}: dQ={dq[coords].item()}, dQ_ref={dq_ref[coords].item()}") + + diff_dk = (dk - dk_ref).abs() + max_idx = diff_dk.argmax() + coords = torch.unravel_index(max_idx, diff_dk.shape) + print(f"dK max diff: {diff_dk.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dK={dk[coords].item()}, dK_ref={dk_ref[coords].item()}") + + diff_dv = (dv - dv_ref).abs() + max_idx = diff_dv.argmax() + coords = torch.unravel_index(max_idx, diff_dv.shape) + print(f"dV max diff: {diff_dv.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dV={dv[coords].item()}, dV_ref={dv_ref[coords].item()}") + + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_learnable_sink", [False, True]) +@pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("softcap", [0.0, 15.0]) +@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("add_unused_qkv", [False, True]) +@pytest.mark.parametrize("add_unused_qkv", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +# @pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + # (1, 1), + # (1, 3), + # (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +@pytest.mark.parametrize("varlen_mode", ["random", "third", "full"]) +# @pytest.mark.parametrize("varlen_mode", ["full"]) +@pytest.mark.parametrize( + "zero_lengths_q, zero_lengths_k", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +@pytest.mark.parametrize( + "unpad_q, unpad_kv", + [ + (True, True), + (False, False), + (True, False), + (False, True), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, + seqlen_k, + 
d, + add_unused_qkv, + causal, + local_enum, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, + varlen_mode, + zero_lengths_q, + zero_lengths_k, + unpad_q, + unpad_kv, +): + local = local_enum > 0 + if local and causal: + pytest.skip() + if ( + causal or local + ): # Right now reference only supports causal attention with seqlen_k == seqlen_q + seqlen_k = seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + batch_size = 49 if seqlen_q <= 1024 else 7 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY: + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + if local_enum == 2: + window_size = (None, window_size[1]) + elif local_enum == 3: + window_size = (window_size[0], None) + if local: + print("window size = ", window_size) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, + batch_size, + device, + mode=varlen_mode, + zero_lengths=zero_lengths_q, + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, + batch_size, + device, + mode=varlen_mode, + zero_lengths=zero_lengths_k, + ) + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + 
query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + # query_padding_mask[:] = True + # query_unused_mask = None + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + if causal or local: + key_padding_mask = query_padding_mask + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + qv=qv, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + if unpad_q: + print("cu_seqlens_q = ", cu_seqlens_q) + else: + print("seqused_q = ", seqused_q) + if unpad_kv: + print("cu_seqlens_k = ", cu_seqlens_k) + else: + print("seqused_k = ", seqused_k) + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False] + # pack_gqa_vals = [False] + # num_splits_vals = [1, 3] + # SplitKV is not supported for hdim >= 192 + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + # SplitKV not supported on SM90 - skip this iteration + if IS_SM90 and num_splits > 1: + continue + out_unpad, lse = flash_attn_varlen_func( + q_unpad if unpad_q else q, + k_unpad if unpad_kv else k, + v_unpad if unpad_kv else v, + cu_seqlens_q=cu_seqlens_q if unpad_q else None, + cu_seqlens_k=cu_seqlens_k if unpad_kv else None, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + seqused_q=seqused_q if not unpad_q else None, + seqused_k=seqused_k if not unpad_kv else None, + causal=causal, + # qv=qv_unpad, + # q_descale=q_descale, + # k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + deterministic=deterministic, + ) + out = output_pad_fn(out_unpad) if unpad_q else out_unpad + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - 
out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + if ( + dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and dv == d + and not has_learnable_sink + # and False + ): + g_unpad = torch.randn_like(out_unpad) + # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # import flash_attn_3_cuda + # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( + # g_unpad, + # q_unpad, + # k_unpad, + # v_unpad, + # out_unpad, + # lse, + # None, + # None, + # None, + # cu_seqlens_q, + # cu_seqlens_k, + # None, None, + # max_seqlen_q, + # max_seqlen_k, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( + out_unpad, + ( + q_unpad if unpad_q else q, + k_unpad if unpad_kv else k, + v_unpad if unpad_kv else v, + ), + g_unpad + ) + dq = dq_pad_fn(dq_unpad) if unpad_q else dq_unpad + dk = dk_pad_fn(dk_unpad) if unpad_kv else dk_unpad + dv = dk_pad_fn(dv_unpad) if unpad_kv else dv_unpad + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + if not unpad_kv: + dk.masked_fill_(rearrange(~key_padding_mask, "b s -> b s 1 1"), 0.0) + dv.masked_fill_(rearrange(~key_padding_mask, "b s -> b s 1 1"), 0.0) + if not unpad_q: + dq.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) if unpad_q else g_unpad + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK 
Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + if VERBOSE: + diff_dq = (dq - dq_ref).abs() + max_idx = diff_dq.argmax() + coords = torch.unravel_index(max_idx, diff_dq.shape) + print(f"dQ max diff: {diff_dq.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dQ={dq[coords].item()}, dQ_ref={dq_ref[coords].item()}") + + diff_dk = (dk - dk_ref).abs() + max_idx = diff_dk.argmax() + coords = torch.unravel_index(max_idx, diff_dk.shape) + print(f"dK max diff: {diff_dk.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dK={dk[coords].item()}, dK_ref={dk_ref[coords].item()}") + + diff_dv = (dv - dv_ref).abs() + max_idx = diff_dv.argmax() + coords = torch.unravel_index(max_idx, diff_dv.shape) + print(f"dV max diff: {diff_dv.max().item()}") + print(f" at coordinates {tuple(c.item() for c in coords)}: dV={dv[coords].item()}, dV_ref={dv_ref[coords].item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("new_kv", [False, True]) +@pytest.mark.parametrize("new_kv", [False]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) +# @pytest.mark.parametrize("rotary_interleaved", [False, True]) +@pytest.mark.parametrize("rotary_interleaved", [True]) +# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +@pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128])) +# @pytest.mark.parametrize("page_size", [None, 128]) +# @pytest.mark.parametrize("page_size", [128]) +# @pytest.mark.parametrize("has_leftpad", [False, True]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +@pytest.mark.parametrize("has_batch_idx", [False]) +@pytest.mark.parametrize("varlen_q", [False, True]) +# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# 
@pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # # (1, 128 * 1024), + # # (16, 128 * 1024), + # (128, 128), + # (256, 512), # To test appending KV with more than 1 block + # (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + has_learnable_sink, + mha_type, + dtype, +): + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if page_size is not None and IS_SM90: + pytest.xfail("paged KV not supported on SM90") + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [d] + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) else [0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + # has_qv = d == 64 and dv >= 256 + has_qv = False + q = ( + torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + if has_qv: + qv = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random" + ) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input( + q, query_padding_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = ( + rearrange(qv, "b s ... 
-> (b s) ...")[indices_q] if has_qv else None + ) + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + + seqlen_new = ( + seqlen_q + if seqlen_new_eq_seqlen_q + else torch.randint(1, seqlen_q + 1, (1,)).item() + ) + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = ( + torch.randn( + batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + v = ( + torch.randn( + batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask( + seqlen_new, batch_size, device, mode="random" + ) + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input( + k, key_new_padding_mask + ) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + d, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + v_cache = ( + torch.randn( + batch_size_cache, + seqlen_k, + nheads_k, + dv, + device=device, + dtype=dtype_ref, + ) + .to(dtype) + .to(dtype_ref) + ) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, + page_size, + batch_size_cache, + nheads_k, + d, + dv, + device, + dtype, + dtype_ref, + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + ( + seqlen_k + - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + + 1 + ) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat( + [ + torch.randint( + 0, + cache_seqlens[i].item(), + (1,), + dtype=torch.int32, + device=device, + ) + if cache_seqlens[i].item() > 0 + else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size) + ] + ) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm( + batch_size_cache, dtype=torch.int32, device=device + )[:batch_size] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = ( + key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + ) + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, + arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k), + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = 
torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = ( + k_cache if not has_batch_idx else k_cache[cache_batch_idx] + ).clone() + v_cache_ref = ( + v_cache if not has_batch_idx else v_cache[cache_batch_idx] + ).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, + arange < cache_seqlens_expanded + k_new_seqlens, + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat( + k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + v_cache_rep = repeat( + v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k + ) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + learnable_sink=learnable_sink, + attention_chunk=attention_chunk, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + learnable_sink=learnable_sink, + attention_chunk=attention_chunk, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + # num_splits_vals = [1, 0] + # SplitKV is not supported for hdim >= 192 + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] + # precompute_metadata_vals = [False, True] + precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product( + num_splits_vals, precompute_metadata_vals + ): + # SplitKV not supported on SM90 - skip this iteration + if IS_SM90 and num_splits > 1: + continue + # if precompute_metadata: + # scheduler_metadata = get_scheduler_metadata( + # batch_size, max_seqlen_q if 
varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + # cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + # cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + # max_seqlen_k_new=seqlen_new, page_size=page_size, + # causal=causal, window_size=window_size, attention_chunk=attention_chunk, + # num_splits=num_splits + # ) + # else: + # scheduler_metadata = None + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + # out, lse, *rest = flash_attn_with_kvcache( + out, lse, *rest = flash_attn_varlen_func( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + # k if not new_kv or not varlen_q else k_unpad, + # v if not new_kv or not varlen_q else v_unpad, + # qv=qv if not varlen_q else qv_unpad, + # rotary_cos=cos, + # rotary_sin=sin, + seqused_k=cache_seqlens, + # cache_batch_idx=cache_batch_idx, + # cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + # cu_seqlens_k_new=cu_seqlens_k_new, + # rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + learnable_sink=learnable_sink, + # attention_chunk=attention_chunk, + # rotary_interleaved=rotary_interleaved, + # scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + # return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) + if not has_batch_idx + else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) + if not has_batch_idx + else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[ + ( + page_table + if not has_batch_idx + else page_table[cache_batch_idx] + ).flatten() + ], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose( + v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3 + ) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3 + ) + else: + assert torch.allclose( + k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1 + ) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * ( + out_pt - out_ref + ).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * ( + out_pt - out_ref + ).abs().mean().item() + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (256, 256)]) +def test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, causal, dtype): + if IS_SM90 and d == 64 and not causal: + pytest.xfail("SM90 backward: d=64 + non-causal has invalid MMA tile config (m_block=80)") + + from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd + + device = "cuda" + torch.random.manual_seed(42) + batch_size = 2 + nheads = 4 + + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + + out, lse = _flash_attn_fwd(q, k, v, causal=causal, return_lse=True) + dout = torch.randn_like(out) + + dq_ref, dk_ref, dv_ref = _flash_attn_bwd(q, k, v, out, dout, lse, causal=causal) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dq_out, dk_out, dv_out = _flash_attn_bwd( + q, k, v, out, dout, lse, causal=causal, dq=dq, dk=dk, dv=dv + ) + + assert dq_out is dq + assert dk_out is dk + assert dv_out is dv + assert torch.allclose(dq, dq_ref, atol=1e-5, rtol=1e-5) + assert torch.allclose(dk, dk_ref, atol=1e-5, rtol=1e-5) + assert torch.allclose(dv, dv_ref, atol=1e-5, rtol=1e-5) + + +def _generate_block_kvcache( + seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref +): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + v_cache_paged = ( + torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) + .to(dtype) + .to(dtype_ref) + ) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +def attention_combine_ref(out_partial, lse_partial): + """ + out_partial: (num_splits, batch_size, seqlen, nheads, d) + lse_partial: (num_splits, batch_size, seqlen, nheads) + """ + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where( + torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale + ) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float32]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) +# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) +# @pytest.mark.parametrize("seqlen", [15]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) +# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) +# @pytest.mark.parametrize("num_splits", [11]) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + device = "cuda" + # set seed + torch.random.manual_seed(1) + batch_size = 5 + nheads = 16 + # batch_size = 1 + # nheads = 1 + # Create tensors in the expected format: (num_splits, batch_size, seqlen, nheads, d) and (num_splits, batch_size, seqlen, nheads) + out_partial = torch.randn( + num_splits * 2, + batch_size, + nheads, + seqlen, + d, + device=device, + dtype=torch.float32, + ).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn( + num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32 + ).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + # To test short-circuiting based on num_splits + lse_partial[num_splits // 2 :, : batch_size // 3] = -float("inf") + + # Test with LSE returned (default behavior) + out, lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, return_lse=True + ) + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + out_pt = out_ref.to(dtype) + + print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + multiple = 2 + assert ( + (out - out_ref).abs().max().item() + <= multiple * (out_pt - out_ref).abs().max().item() + ) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) + + # Test with LSE not returned + out_no_lse, lse_no_lse = flash_attn_combine( + out_partial, lse_partial, out_dtype=dtype, return_lse=False + ) + assert lse_no_lse is None, "LSE should be None when return_lse=False" + assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), ( + "Output should be the same regardless of return_lse" + ) diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py new file mode 100644 index 00000000000..c2a649067bf --- /dev/null +++ 
b/tests/cute/test_flash_attn_race_condition.py @@ -0,0 +1,776 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + +import math +import itertools +import os + +import pytest +import torch + +from einops import rearrange, repeat + +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from flash_attn.cute.testing import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, + pad_input, + unpad_input, +) +from flash_attn.cute.interface import ( + flash_attn_func, + flash_attn_varlen_func, + flash_attn_combine, + _flash_attn_bwd, +) + + +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" +IS_SM90 = torch.cuda.get_device_capability()[0] == 9 +INCREASED_TRIALS = False + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["gqa"]) +# @pytest.mark.parametrize("has_learnable_sink", [False, True]) +@pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [True]) +# @pytest.mark.parametrize("softcap", [0.0, 15.0]) +@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (4224, 4224), + (2000, 4000), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, + seqlen_k, + d, + causal, + local_enum, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, +): + local = local_enum > 0 + if local and causal: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + torch.cuda.empty_cache() + torch.cuda.synchronize() + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 1 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
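+            # (Hedged note: softcap applies a tanh-style cap to the qk scores, so scaling q
+            # down by softcap / 4 keeps the scores in a range where the cap is exercised
+            # without saturating it.)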
+ q_ref = q_ref * softcap / 4 + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + if local_enum == 2: + window_size = (None, -window_size[1]) + elif local_enum == 3: + window_size = (-window_size[0], None) + if local: + print("window size = ", window_size) + # window_size = (-1, -1) if not local else (16, 0) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # k_extended = repeat(k_ref, "b s h d -> b s (h k) d", k=nheads // nheads_kv) + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float() + # # if qv is not None: + # # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # num_splits_vals = [1, 3] + # pack_gqa_vals = [False, True, None] + # SplitKV is not supported for hdim >= 192 + pack_gqa_vals = [False] + # num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] + num_splits_vals = [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out, lse = flash_attn_func( + q, + k, + v, + causal=causal, + # qv=qv, + # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + softcap=softcap, + 
learnable_sink=learnable_sink, + pack_gqa=pack_gqa, + num_splits=num_splits, + deterministic=deterministic, + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + if ( + dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and softcap == 0.0 + and dv == d + and learnable_sink is None + # and False + ): + if IS_SM90 and mha_type != "mha": + pytest.xfail("SM90 backward: GQA/MQA has tensor layout issue (qhead_per_kvhead > 1)") + if IS_SM90 and local: + pytest.xfail("SM90 backward: local attention not supported yet") + g = torch.randn_like(out) + # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + # breakpoint() + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + num_iters = 10_000 if INCREASED_TRIALS else 1000 + for i in range(num_iters): + dq2, dk2, dv2, = _flash_attn_bwd( + q, k, v, out, g, lse, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + deterministic=True, + ) + + 
diff_dq = (dq - dq2).abs() + max_idx = diff_dq.argmax() + print(f"dQ max diff: {diff_dq.max().item()}") + print(f" at index {max_idx.item()}: dQ={dq.flatten()[max_idx].item()}, dQ2={dq2.flatten()[max_idx].item()}") + + diff_dk = (dk - dk2).abs() + max_idx = diff_dk.argmax() + print(f"dK max diff: {diff_dk.max().item()}") + print(f" at index {max_idx.item()}: dK={dk.flatten()[max_idx].item()}, dK2={dk2.flatten()[max_idx].item()}") + + diff_dv = (dv - dv2).abs() + max_idx = diff_dv.argmax() + print(f"dV max diff: {diff_dv.max().item()}") + print(f" at index {max_idx.item()}: dV={dv.flatten()[max_idx].item()}, dV2={dv2.flatten()[max_idx].item()}") + + # print(f"dQ max diff with myself: {(dq - dq2).abs().max().item()}") + # print(f"dK max diff with myself: {(dk - dk2).abs().max().item()}") + # print(f"dV max diff with myself: {(dv - dv2).abs().max().item()}") + # print(f"dQ mean diff with myself: {(dq - dq2).abs().mean().item()}") + # print(f"dK mean diff with myself: {(dk - dk2).abs().mean().item()}") + # print(f"dV mean diff with myself: {(dv - dv2).abs().mean().item()}") + + assert torch.equal(dq, dq2) + assert torch.equal(dk, dk2) + assert torch.equal(dv, dv2) + + print(f"✅ Iteration {i} passed!") + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["gqa"]) +# @pytest.mark.parametrize("has_learnable_sink", [False, True]) +@pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [True]) +# @pytest.mark.parametrize("softcap", [0.0, 15.0]) +@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) +# @pytest.mark.parametrize("local_enum", [0, 1]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("add_unused_qkv", [False, True]) +@pytest.mark.parametrize("add_unused_qkv", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +# @pytest.mark.parametrize("d", [128, 192]) +@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1024, 1024), + (2048, 2048), + ], +) +@pytest.mark.parametrize("varlen_mode", ["random", "third", "full"]) +# @pytest.mark.parametrize("varlen_mode", ["random"]) +@pytest.mark.parametrize( + "zero_lengths_q, zero_lengths_k", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, + seqlen_k, + d, + add_unused_qkv, + causal, + local_enum, + softcap, + deterministic, + has_qv, + has_learnable_sink, + mha_type, + dtype, + varlen_mode, + zero_lengths_q, + zero_lengths_k, +): + local = local_enum > 0 + if local and causal: + pytest.skip() + if ( + causal or local + ): # Right now reference only supports causal attention with seqlen_k == seqlen_q + seqlen_k = seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + 
int(causal) * 2 + int(local)) + batch_size = 49 if seqlen_q <= 1024 else 7 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + dv_vals = [d] # override + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = ( + (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + ) + if local_enum == 2: + window_size = (None, window_size[1]) + elif local_enum == 3: + window_size = (window_size[0], None) + if local: + print("window size = ", window_size) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, + batch_size, + device, + mode=varlen_mode, + zero_lengths=zero_lengths_q, + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, + batch_size, + device, + mode=varlen_mode, + zero_lengths=zero_lengths_k, + ) + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + # query_padding_mask[:] = True + # query_unused_mask = None + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + if causal or local: + key_padding_mask = query_padding_mask + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) 
= generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + qv=qv, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + print("cu_seqlens_q = ", cu_seqlens_q) + print("cu_seqlens_k = ", cu_seqlens_k) + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + out_unpad, lse = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + # max_seqlen_k, + # seqused_q=seqused_q, + # seqused_k=seqused_k, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + causal=causal, + # qv=qv_unpad, + # q_descale=q_descale, + # k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + num_splits=1, + pack_gqa=False, + deterministic=deterministic, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. 
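+        # fwd_atol is the round-off floor measured on out_ref itself (2 * |out_ref + 0.3 - 0.3 - out_ref|
+        # computed above), and rtol is 2 (3 with softcap), so the kernel may be at most a few
+        # times worse than the low-precision PyTorch reference.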
+ assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + if ( + dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and dv == d + and not has_learnable_sink + # and False + ): + g_unpad = torch.randn_like(out_unpad) + # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # import flash_attn_3_cuda + # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( + # g_unpad, + # q_unpad, + # k_unpad, + # v_unpad, + # out_unpad, + # lse, + # None, + # None, + # None, + # cu_seqlens_q, + # cu_seqlens_k, + # None, None, + # max_seqlen_q, + # max_seqlen_k, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( + out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad + ) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad( + out_ref, (q_ref, k_ref, v_ref), g + ) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv 
- dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + num_iters = 10_000 if INCREASED_TRIALS else 1000 + + for i in range(num_iters): + dq_unpad2, dk_unpad2, dv_unpad2 = _flash_attn_bwd( + q_unpad, k_unpad, v_unpad, out_unpad, g_unpad, lse, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + deterministic=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen_q, + max_seqlen_k=seqlen_k, + ) + + diff_dq = (dq_unpad - dq_unpad2).abs() + max_idx = diff_dq.argmax() + if i % 100 == 0: + print(f"dQ max diff: {diff_dq.max().item()}") + print(f" at index {max_idx.item()}: dQ={dq_unpad.flatten()[max_idx].item()}, dQ2={dq_unpad2.flatten()[max_idx].item()}") + + diff_dk = (dk_unpad - dk_unpad2).abs() + max_idx = diff_dk.argmax() + if i % 100 == 0: + print(f"dK max diff: {diff_dk.max().item()}") + print(f" at index {max_idx.item()}: dK={dk_unpad.flatten()[max_idx].item()}, dK2={dk_unpad2.flatten()[max_idx].item()}") + + diff_dv = (dv_unpad - dv_unpad2).abs() + max_idx = diff_dv.argmax() + if i % 100 == 0: + print(f"dV max diff: {diff_dv.max().item()}") + print(f" at index {max_idx.item()}: dV={dv_unpad.flatten()[max_idx].item()}, dV2={dv_unpad2.flatten()[max_idx].item()}") + + assert torch.equal(dq_unpad, dq_unpad2) + assert torch.equal(dk_unpad, dk_unpad2) + assert torch.equal(dv_unpad, dv_unpad2) + + if i % 100 == 0: + print(f"✅ Iteration {i} passed!") \ No newline at end of file diff --git a/tests/cute/test_flash_attn_varlen.py b/tests/cute/test_flash_attn_varlen.py new file mode 100644 index 00000000000..1666a08fb00 --- /dev/null +++ b/tests/cute/test_flash_attn_varlen.py @@ -0,0 +1,315 @@ +import itertools +from typing import Optional +from einops import rearrange +import pytest + +import torch +import torch.nn.functional as F +from flash_attn.cute import flash_attn_varlen_func + +IS_SM90 = torch.cuda.get_device_capability()[0] == 9 + + +@pytest.mark.parametrize("B", [1, 7, 20]) +@pytest.mark.parametrize("H", [1, 4, 6]) +@pytest.mark.parametrize("D", [64, 128]) +@pytest.mark.parametrize("min_seq_len", [1, 32, 128]) +@pytest.mark.parametrize("max_seq_len", [8, 64, 2048]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("softmax_scale", [None, 0.1]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +def test_varlen( + B, + H, + D, + min_seq_len, + max_seq_len, + causal, + softmax_scale, + dtype, + mha_type, +): + if min_seq_len > max_seq_len: + pytest.skip("Skipping min_seq_len > max_seq_len") + + q, k, v, cu_seqlens_q, cu_seqlens_k, total_q, total_k = generate_varlen_args( + batch_size=B, + n_heads=H, + d_head=D, + min_len=min_seq_len, + max_len=max_seq_len, + mha_type=mha_type, + dtype=dtype + ) + + # SM90 backward pass doesn't support varlen yet + skip_backward = IS_SM90 + + ok = check_varlen_vs_torch_flash( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + total_q=total_q, total_k=total_k, + softmax_scale=softmax_scale, + causal=causal, + mha_type=mha_type, + skip_backward=skip_backward, + ) + assert ok + +def check_varlen_vs_torch_flash( + q, k, v, + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + total_q=None, + total_k=None, + softmax_scale=None, + causal=True, + mha_type='mha', + softcap=0.0, + atol=3e-2, + rtol=3e-2, + skip_backward=False, +): + assert q.requires_grad and k.requires_grad and v.requires_grad, "Set requires_grad=True on inputs" + + def 
clone_like(t): + c = t.clone().detach().requires_grad_(True) + return c + + q_fa, k_fa, v_fa = map(clone_like, (q, k, v)) + q_t, k_t, v_t = map(clone_like, (q, k, v)) + + if cu_seqlens_q is not None: + cu_seqlens_q_fa = cu_seqlens_q.clone() + cu_seqlens_q_t = cu_seqlens_q.clone() + else: + cu_seqlens_q_fa = None + cu_seqlens_q_t = None + + if cu_seqlens_k is not None: + cu_seqlens_k_fa = cu_seqlens_k.clone() + cu_seqlens_k_t = cu_seqlens_k.clone() + else: + cu_seqlens_k_fa = None + cu_seqlens_k_t = None + + out_fa, lse_fa = flash_attn_varlen_func( + q_fa, k_fa, v_fa, + cu_seqlens_q=cu_seqlens_q_fa, + cu_seqlens_k=cu_seqlens_k_fa, + seqused_q=seqused_q, + seqused_k=seqused_k, + softmax_scale=(1.0 / q.shape[-1]**0.5) if softmax_scale is None else softmax_scale, + causal=causal, + window_size=(None, None), + learnable_sink=None, + softcap=softcap, + pack_gqa=None, + ) + + out_t = torch_flash_ref( + q_t, k_t, v_t, + cu_seqlens_q=cu_seqlens_q_t, + cu_seqlens_k=cu_seqlens_k_t, + seqused_q=seqused_q, + seqused_k=seqused_k, + total_q=total_q, + total_k=total_k, + softmax_scale=softmax_scale, + causal=causal, + mha_type=mha_type, + ) + + + ok_fwd = torch.allclose(out_fa.float(), out_t.float(), atol=atol, rtol=rtol) + if not ok_fwd: + return False + + # Skip backward if not supported (e.g., SM90 varlen) + if skip_backward: + return True + + # Use the same upstream gradient to compare backward paths + grad_out = torch.randn_like(out_fa) + + grad_fa = clone_like(grad_out) + grad_t = clone_like(grad_out) + + # Cute bwd + out_fa.backward(grad_fa, retain_graph=False) + dq_fa, dk_fa, dv_fa = q_fa.grad, k_fa.grad, v_fa.grad + + # Ref bwd + out_t.backward(grad_t, retain_graph=False) + dq_t, dk_t, dv_t = q_t.grad, k_t.grad, v_t.grad + + # mean_ok_q = _stats("dQ", dq_fa, dq_t, atol=atol, rtol=rtol) + # mean_ok_k = _stats("dK", dk_fa, dk_t, atol=atol, rtol=rtol) + # mean_ok_v = _stats("dV", dv_fa, dv_t, atol=atol, rtol=rtol) + + # return mean_ok_q and mean_ok_k and mean_ok_v + + ok_q = torch.allclose(dq_fa.float(), dq_t.float(), atol=atol, rtol=rtol) + ok_k = torch.allclose(dk_fa.float(), dk_t.float(), atol=atol, rtol=rtol) + ok_v = torch.allclose(dv_fa.float(), dv_t.float(), atol=atol, rtol=rtol) + # print(f"Close? 
dQ={ok_q}, dK={ok_k}, dV={ok_v}") + return ok_q and ok_k and ok_v + +def generate_varlen_args( + batch_size=8, + n_heads=16, + d_head=128, + min_len=32, + max_len=64, + mha_type="mha", + dtype = torch.bfloat16, +): + + torch.manual_seed(0) + device = "cuda" + + assert mha_type in ["mha", "mqa", "gqa"] + + lens_q = torch.randint(low=min_len, high=max_len + 1, size=(batch_size,)) + lens_k = lens_q.clone() + + cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32), lens_q.cumsum(0)]) + cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32), lens_k.cumsum(0)]) + + total_q = cu_seqlens_q[-1] + total_k = cu_seqlens_k[-1] + + cu_seqlens_q = cu_seqlens_q.contiguous().to(dtype=torch.int32, device=device) + cu_seqlens_k = cu_seqlens_k.contiguous().to(dtype=torch.int32, device=device) + + if mha_type == "gqa": + H = 3 * n_heads + H_kv = n_heads + elif mha_type == "mha": + H = H_kv = n_heads + else: # MQA + H = n_heads + H_kv = 1 + + d_head_v = d_head + + q = torch.randn(total_q, H, d_head, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(total_k, H_kv, d_head, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(total_k, H_kv, d_head_v, device=device, dtype=dtype, requires_grad=True) + + return q, k, v, cu_seqlens_q, cu_seqlens_k, total_q, total_k + +# Simple for loop over batch dim implementation +def torch_flash_ref( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor = None, + cu_seqlens_k: torch.Tensor = None, + total_q: int = 0, + total_k: int = 0, + softmax_scale: Optional[float] = None, + causal: bool = False, + **kwargs + ): + + """ + q: (total_q, H, d) if cu_seqlens_q is not None, otherwise (B, L, H, d) + k: (total_k, H_kv, d) if cu_seqlens_k is not None, otherwise (B, L, H_kv, d) + v: (total_k, H_kv, d_v) if cu_seqlens_k is not None, otherwise (B, L, H_kv, d_v) + cu_seqlens_q: (B+1,) int32, cumulative + cu_seqlens_k: (B+1,) int32, cumulative + + seqused_q: (B+1,) int32 + seqused_k: (B+1,) int32 + Returns: + out packed like q: (total_q, H, d_v) + """ + + if cu_seqlens_q is not None: + assert cu_seqlens_q.dim() == 1 + assert total_q == q.shape[0] + assert q.dim() == 3 + H = q.shape[1] + B = cu_seqlens_q.shape[0] - 1 + else: + assert q.dim() == 4 + H = q.shape[2] + B = q.shape[0] + + if cu_seqlens_k is not None: + assert cu_seqlens_k.dim() == 1 + assert total_k == k.shape[0] == v.shape[0] + assert k.dim() == v.dim() == 3 + H_kv = k.shape[1] + B_kv = cu_seqlens_k.shape[0] - 1 + else: + assert k.dim() == v.dim() == 4 + assert k.shape[0] == v.shape[0] + H_kv = k.shape[2] + B_kv = k.shape[0] + + d = q.shape[-1] + d_v = v.shape[-1] + + assert H_kv == v.shape[-2] + assert d == k.shape[-1] + assert B == B_kv + + assert q.device == k.device == v.device + assert q.is_floating_point() and k.is_floating_point() and v.is_floating_point() + + device = q.device + dtype = q.dtype + + hcseq_q = cu_seqlens_q.to(device='cpu') + hcseq_k = cu_seqlens_k.to(device='cpu') + + outs = [] + for b in range(B): + if hcseq_q is not None: + q_start, q_end = int(hcseq_q[b]), int(hcseq_q[b+1]) + qb = q[q_start:q_end] + else: + qb = q[b] + + if hcseq_k is not None: + k_start, k_end = int(hcseq_k[b]), int(hcseq_k[b+1]) + kb = k[k_start:k_end] + vb = v[k_start:k_end] + else: + kb = k[b] + vb = v[b] + + qb = qb.permute(1, 0, 2).unsqueeze(0) + kb = kb.permute(1, 0, 2).unsqueeze(0) + vb = vb.permute(1, 0, 2).unsqueeze(0) + + ob = F.scaled_dot_product_attention( + qb, kb, vb, + attn_mask=None, + dropout_p=0.0, + is_causal=causal, + scale=softmax_scale, + 
enable_gqa=H_kv!=H + ) + + ob = ob.squeeze(0).permute(1, 0, 2).contiguous() + outs.append(ob) + + if cu_seqlens_q is not None: + out = torch.cat(outs, dim=0).to(device=device, dtype=dtype) + else: + out = torch.stack(outs, dim=0).to(device=device, dtype=dtype) + return out + +@torch.no_grad() +def _stats(name, a, b, atol, rtol): + diff = (a - b).float() + mean_abs = diff.abs().mean().item() + mean_rel = (diff.abs().mean() / b.abs().clamp_min(1e-6).mean().item()) + print(f"{name}: mean_abs={mean_abs:.4e}, mean_rel={mean_rel:.4e}, sum_fa={a.sum()}, sum_ref={b.sum()}") + return mean_abs < atol and mean_rel < rtol \ No newline at end of file diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py new file mode 100644 index 00000000000..f830fcb0afb --- /dev/null +++ b/tests/cute/test_mask_mod.py @@ -0,0 +1,1107 @@ +# mask mod test script +# REFACTORED to use _flash_attn_fwd as the kernel entrypoint +# +# Test Organization: +# - test_static_masks: Fast tests for masks that don't need per-seqlen compilation +# (identity, document, block_diagonal, etc.) with comprehensive seqlen coverage +# - test_parameterized_masks: Slower tests for masks that require recompilation per +# seqlen pair (causal, block_causal, sliding_window) with reduced seqlen coverage +# +# Usage: +# pytest test_mask_mod.py::test_static_masks # Run only fast tests +# pytest test_mask_mod.py::test_parameterized_masks # Run only slow tests +# pytest test_mask_mod.py # Run all tests + +import math + +import pytest +import torch +from torch.nn.attention.flex_attention import create_block_mask, flex_attention +import torch.nn.functional as F + +from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd +from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch +from mask_mod_definitions import get_mask_pair, random_doc_id_tensor +COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] + + +@pytest.fixture(autouse=True) +def reset_torch_state(): + """Reset torch dynamo/compile state between tests to avoid state pollution.""" + torch._dynamo.reset() + torch.cuda.empty_cache() + + yield + + torch._dynamo.reset() + torch.cuda.empty_cache() + +def create_tensors( + batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype +): + device = "cuda" + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype) + k = torch.randn( + batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype + ) + out = torch.empty( + batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype + ) + lse = torch.empty(batch_size, nheads, seqlen_q, device=device, dtype=torch.float32) + + return { + "q": q, + "k": k, + "v": v, + "out": out, + "lse": lse, + } + + +def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: tuple[int, int] | None = None): + """Compute reference using flex_attention for custom mask_mods""" + batch_size, seqlen_q, nheads, headdim = tensors["q"].shape + _, seqlen_k, nheads_kv, _ = tensors["k"].shape + + q = tensors["q"].transpose(1, 2) + k = tensors["k"].transpose(1, 2) + v = tensors["v"].transpose(1, 2) + + if nheads != nheads_kv: + repeat_factor = nheads // nheads_kv + k = k.repeat_interleave(repeat_factor, dim=1) + v = v.repeat_interleave(repeat_factor, dim=1) + + scale = 1.0 / math.sqrt(headdim) + + # Handle identity (no masking) case + if mask_mod_flex is None: + out_ref = F.scaled_dot_product_attention(q, k, v, scale=scale) + return 
out_ref.transpose(1, 2).contiguous() + + block_mask_kwargs = {} + if block_size is not None: + block_mask_kwargs["BLOCK_SIZE"] = block_size + + block_mask = create_block_mask( + mask_mod_flex, + B=batch_size, + H=nheads, + Q_LEN=seqlen_q, + KV_LEN=seqlen_k, + device=q.device, + **block_mask_kwargs, + ) + out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale, enable_gqa=True) + return out_ref.transpose(1, 2).contiguous() + + +SEQLEN_PAIRS_COMPREHENSIVE = [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), +] + +SEQLEN_PAIRS_SMOKE = [ + (128, 128), + (256, 256), + (113, 203), + (1024, 1024), + (128, 8192) +] + + +def _run_mask_test( + seqlen_q, + seqlen_k, + nheads, + kv_mode, + headdim, + dtype, + mask_name, + window_size, + window_left, + window_right, + tile_m, + tile_n, + use_block_sparsity, + needs_backward=False, +): + torch.manual_seed(42) + + if mask_name == "sliding_window": + assert window_size is not None, ( + "window_size must be specified for sliding_window" + ) + if seqlen_q > seqlen_k: + pytest.skip( + f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for sliding_window" + ) + + # Determine nheads_kv based on mode + if kv_mode == "mha": + nheads_kv = nheads + pack_gqa = False + elif kv_mode == "gqa": + if COMPUTE_CAPABILITY < 9: + pytest.xfail("pack_gqa requires SM90+") + nheads_kv = nheads // 4 + pack_gqa = True + elif kv_mode == "mqa": + nheads_kv = 1 + pack_gqa = False + else: + raise ValueError(f"Unknown kv_mode: {kv_mode}") + + batch_size = 1 + headdim_v = headdim + + aux_tensors_arg = None + mask_mod_cute, mask_mod_flex = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + if mask_name == "document": + doc_len = max(seqlen_q, seqlen_k) + doc_ids = random_doc_id_tensor(nheads, batch_size, doc_len, device="cuda").to( + dtype=torch.int32, device="cuda" + ) + original_flex_mask = mask_mod_flex + + def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): + return original_flex_mask(b, h, q_idx, kv_idx, doc_ids) + + aux_tensors_arg = [doc_ids] + elif mask_name == "ima": + bias_threshold = (seqlen_k // 4) * 3 + bias = torch.full((seqlen_k,), bias_threshold, dtype=torch.int32, device="cuda") + original_flex_mask = mask_mod_flex + + def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): + return original_flex_mask(b, h, q_idx, kv_idx, bias) + + aux_tensors_arg = [bias] + causal = False + + if causal and seqlen_k < seqlen_q: + pytest.skip("causal masking requires seqlen_k >= seqlen_q") + + tensors = create_tensors( + batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype + ) + + # SM100 uses sparse_tile_m = 2*tile_m to match forward q_stage=2 pipelining + if COMPUTE_CAPABILITY == 10: + sparse_tile_m = 2 * tile_m + else: + sparse_tile_m = tile_m + + block_mask_nheads = 1 if pack_gqa else nheads + bm = create_block_mask( + mask_mod_flex, + batch_size, + block_mask_nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm.as_tuple() + + # SM90 block-sparse backward expects BlockMask granularity (128, 128) regardless of fwd tiling. 
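+    # Only the Q-direction (backward) counts/indices are taken from this (128, 128) mask;
+    # the forward KV-direction tensors unpacked above keep the requested (sparse_tile_m, tile_n) tiling.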
+ if COMPUTE_CAPABILITY == 9 and use_block_sparsity and (sparse_tile_m, tile_n) != (128, 128): + bm_bwd = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(128, 128), + ) + ( + _seq_q, + _seq_k, + _kv_mask_cnt, + _kv_mask_idx, + _full_kv_cnt, + _full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm_bwd.as_tuple() + + softmax_scale = 1.0 / math.sqrt(headdim) + + block_sparse_mask_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + ) if use_block_sparsity else None + + # Backward uses Q-direction (transposed) sparse tensors + block_sparse_mask_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, + ) if use_block_sparsity else None + + out_tuple = _flash_attn_fwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=tensors["out"], + lse=tensors["lse"], + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=softmax_scale, + causal=causal, + softcap=None, + window_size_left=window_left, + window_size_right=window_right, + learnable_sink=None, + m_block_size=tile_m, + n_block_size=tile_n, + pack_gqa=pack_gqa, + _compute_capability=None, + score_mod=None, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + aux_tensors=aux_tensors_arg, + ) + + out_cute = out_tuple[0] + lse_cute = out_tuple[1] + tensors_fp32 = { + k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v + for k, v in tensors.items() + } + + block_size = (tile_m, tile_n) + out_ref_fp32 = compute_reference_flex_attn(tensors_fp32, mask_mod_flex, block_size) + out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, block_size) + out_pt = out_ref.clone() + + # Check for invalid values + assert out_cute.shape == out_ref_fp32.shape == out_ref.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + + # Compute numerical tolerance (matching flash attention tests) + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + ref_error = (out_ref - out_ref_fp32).abs().max().item() + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + mask_desc = f"mask_mod={mask_name}" + if mask_name == "sliding_window" and window_size is not None: + mask_desc += f"(w={window_size})" + + print( + f"\n{mask_desc} @ Q={seqlen_q}, K={seqlen_k}, H={nheads}/{nheads_kv} ({kv_mode}), " + f"D={headdim}, M={tile_m}, N={tile_n}" + ) + print(" Reference implementation: FlexAttention") + print(f" Reference vs FP32: {ref_error:.2e}") + print(f" PyTorch vs FP32: {pt_error:.2e}") + print(f" Kernel vs FP32: {cute_error:.2e}") + print(f" Tolerance: rtol={rtol} * {pt_error:.2e} + {fwd_atol:.2e}") + print(f" Error ratio: {cute_error / max(pt_error, 1e-10):.2f}") + + # Debug: show some sample values if error is large + if cute_error > 1e-2: + print(f" DEBUG: Sample kernel output: {out_cute[0, 0, 0, :5]}") + print(f" DEBUG: Sample reference output: {out_ref_fp32[0, 0, 0, :5]}") + print(f" DEBUG: Max diff location: {(out_cute - out_ref_fp32).abs().argmax()}") + max_diff_idx = (out_cute - out_ref_fp32).abs().argmax() + max_diff_coords = 
torch.unravel_index(max_diff_idx, out_cute.shape) + print(f" DEBUG: Max diff at coords: {max_diff_coords}") + print(f" DEBUG: Kernel value: {out_cute[max_diff_coords]:.6f}") + print(f" DEBUG: Reference value: {out_ref_fp32[max_diff_coords]:.6f}") + + # Use the same assertion logic as FlashAttention tests + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + if needs_backward: + q = tensors["q"] + k = tensors["k"] + v = tensors["v"] + + # Create grad_out once and reuse + grad_out = torch.randn_like(out_cute) + + # Create block_mask for flex reference + flex_block_mask = create_block_mask( + mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k, + device="cuda", BLOCK_SIZE=(tile_m, tile_n), + ) + + dq_cute, dk_cute, dv_cute = run_cute_mask_bwd( + q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute, + block_sparse_mask_bwd=block_sparse_mask_bwd, tile_m=tile_m, tile_n=tile_n, + aux_tensors=aux_tensors_arg, + ) + _, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( + q, k, v, flex_block_mask, grad_out, dtype=torch.float32 + ) + _, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( + q, k, v, flex_block_mask, grad_out + ) + + # Check for invalid values + assert not torch.isnan(dq_cute).any(), "dQ contains NaN" + assert not torch.isnan(dk_cute).any(), "dK contains NaN" + assert not torch.isnan(dv_cute).any(), "dV contains NaN" + + bwd_rtol = 2 + min_seqlen = min(seqlen_q, seqlen_k) + bwd_atol_floor = 1e-5 if min_seqlen >= 64 else 3e-5 + dq_atol = max(bwd_atol_floor, 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item()) + dk_atol = max(bwd_atol_floor, 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item()) + dv_atol = max(bwd_atol_floor, 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item()) + + dq_ref = dq_ref_fp32.to(dtype) + dk_ref = dk_ref_fp32.to(dtype) + dv_ref = dv_ref_fp32.to(dtype) + + pt_dq_err = (dq_pt - dq_ref).abs().max().item() + pt_dk_err = (dk_pt - dk_ref).abs().max().item() + pt_dv_err = (dv_pt - dv_ref).abs().max().item() + + cute_dq_err = (dq_cute - dq_ref).abs().max().item() + cute_dk_err = (dk_cute - dk_ref).abs().max().item() + cute_dv_err = (dv_cute - dv_ref).abs().max().item() + + print(" Backward comparison:") + print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}") + print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}") + print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}") + + assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}" + assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}" + assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}" + + +def test_mask_mod_ima_partial_block(): + _run_mask_test( + seqlen_q=257, + seqlen_k=257, + nheads=1, + kv_mode="mha", + headdim=128, + dtype=torch.bfloat16, + mask_name="ima", + window_size=None, + window_left=None, + window_right=None, + tile_m=128, + tile_n=128, + use_block_sparsity=True, + needs_backward=True, + ) + + +# Q boundary seqlens: NOT multiples of tile_m (128) +# These exercise the fix for is_full_block tiles not masking OOB Q rows in backward +Q_BOUNDARY_SEQLEN_PAIRS = [ + (200, 200), # Last m_block: rows 128-199 valid, 200-255 should be masked + (300, 300), # Last m_block: rows 256-299 valid, 300-383 should be masked + (129, 129), # Just 1 element into second 
tile + (255, 255), # Just 1 element short of 2 full tiles + (500, 512), # Q boundary only (K aligned) + (512, 500), # K boundary only (Q aligned) + (333, 444), # Both non-aligned +] + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", Q_BOUNDARY_SEQLEN_PAIRS) +@pytest.mark.parametrize("mask_name", ["block_diagonal", "document"]) +def test_q_boundary_masking_block_sparse_bwd(seqlen_q, seqlen_k, mask_name): + """Test Q boundary masking for block-sparse backward pass. + + This test specifically exercises the fix for the bug where Q rows beyond seqlen_q + were not masked in backward pass for is_full_block=True tiles. + + The bug occurred because: + - In forward, apply_mask_sm100 always checks both Q and K bounds + - In backward, apply_mask_sm100_transposed with is_full_block=True only checked K bounds + - Result: partial last m_blocks had unmasked garbage Q rows contributing to gradients + + Key conditions: + - seqlen_q NOT a multiple of tile_m (128): creates partial last m_block + - Block-sparse with mask_mod: exercises is_full_block=True path + - Backward pass: where the bug manifested + """ + _run_mask_test( + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + nheads=4, + kv_mode="mha", + headdim=128, + dtype=torch.bfloat16, + mask_name=mask_name, + window_size=None, + window_left=None, + window_right=None, + tile_m=128, + tile_n=128, + use_block_sparsity=True, + needs_backward=True, + ) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="Test uses SM100 block mask conventions (2*tile_m)") +def test_single_doc_bwd_minimal(): + """Minimal test to isolate single-document backward pass bug. + + This test uses batch=1, nheads=1, and a single document (all same doc_id) + to make debugging easier. The bug manifests as large numerical errors + in dQ, dK, dV when blocks are classified as "full blocks" due to + the mask returning True for all positions. 
+ + Run with: pytest tests/cute/test_mask_mod.py::test_single_doc_bwd_minimal -v -s + """ + import random + random.seed(42) + torch.manual_seed(42) + + seqlen_q = 384 + seqlen_k = 300 + batch_size = 1 + nheads = 1 + headdim = 128 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + + # Create single-document doc_ids (all same doc_id = 0) + doc_ids = torch.zeros(batch_size, nheads, max(seqlen_q, seqlen_k), dtype=torch.int32, device="cuda") + + from mask_mod_definitions import get_mask_pair + mask_mod_cute, mask_mod_flex = get_mask_pair("document", seqlen_q=seqlen_q, seqlen_k=seqlen_k) + + original_flex_mask = mask_mod_flex + def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): + return original_flex_mask(b, h, q_idx, kv_idx, doc_ids) + + aux_tensors_arg = [doc_ids] + + # Create tensors + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + out = torch.empty(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype) + lse = torch.empty(batch_size, nheads, seqlen_q, device="cuda", dtype=torch.float32) + + sparse_tile_m = 2 * tile_m + bm = create_block_mask( + mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k, + device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + ( + _seq_q, _seq_k, + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, *_, + ) = bm.as_tuple() + + block_sparse_mask_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, full_block_idx=full_kv_idx, + ) + block_sparse_mask_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, full_block_idx=full_q_idx, + ) + + + out_tuple = _flash_attn_fwd( + q=q, k=k, v=v, out=out, lse=lse, + cu_seqlens_q=None, cu_seqlens_k=None, + seqused_q=None, seqused_k=None, page_table=None, + causal=False, softcap=None, + window_size_left=-1, window_size_right=-1, + m_block_size=tile_m, n_block_size=tile_n, pack_gqa=False, + _compute_capability=None, score_mod=None, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, aux_tensors=aux_tensors_arg, + ) + out_cute = out_tuple[0] + lse_cute = out_tuple[1] + + # Backward pass + grad_out = torch.randn_like(out_cute) + + dq_cute, dk_cute, dv_cute = run_cute_mask_bwd( + q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute, + block_sparse_mask_bwd=block_sparse_mask_bwd, + tile_m=tile_m, tile_n=tile_n, + aux_tensors=aux_tensors_arg, + ) + + flex_block_mask = create_block_mask( + mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k, + device="cuda", BLOCK_SIZE=(tile_m, tile_n), + ) + out_ref, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd( + q, k, v, flex_block_mask, grad_out, dtype=torch.float32 + ) + + # Compare + dq_err = (dq_cute - dq_ref.to(dtype)).abs().max().item() + dk_err = (dk_cute - dk_ref.to(dtype)).abs().max().item() + dv_err = (dv_cute - dv_ref.to(dtype)).abs().max().item() + + print(f"dQ error: {dq_err:.2e}") + print(f"dK error: {dk_err:.2e}") + print(f"dV error: {dv_err:.2e}") + + # Assert gradients are correct (this will fail, demonstrating the bug) + assert dq_err < 0.05, f"dQ error too large: {dq_err:.2e}" + assert dk_err < 0.05, f"dK error too large: {dk_err:.2e}" + assert dv_err < 0.05, f"dV error too large: {dv_err:.2e}" + + 
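+# ---------------------------------------------------------------------------
+# Minimal sketch (not called by the tests above or below) of the BlockMask ->
+# BlockSparseTensorsTorch plumbing that the tests repeat inline: the forward
+# kernel consumes the KV-direction counts/indices, the backward kernel the
+# Q-direction (transposed) ones. For brevity this sketch ignores the SM100
+# (2 * tile_m) and SM90 (fixed 128x128 backward granularity) special cases
+# handled explicitly in _run_mask_test; the helper name is illustrative only.
+def _block_sparse_pair_from_mask(
+    mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k, tile_m=128, tile_n=128
+):
+    """Return (fwd, bwd) BlockSparseTensorsTorch built from one flex BlockMask."""
+    bm = create_block_mask(
+        mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k,
+        device="cuda", BLOCK_SIZE=(tile_m, tile_n),
+    )
+    (
+        _seq_q, _seq_k,
+        kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx,
+        q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, *_,
+    ) = bm.as_tuple()
+    # Forward: iterate KV blocks per Q block.
+    fwd = BlockSparseTensorsTorch(
+        mask_block_cnt=kv_mask_cnt, mask_block_idx=kv_mask_idx,
+        full_block_cnt=full_kv_cnt, full_block_idx=full_kv_idx,
+    )
+    # Backward: iterate Q blocks per KV block (transposed direction).
+    bwd = BlockSparseTensorsTorch(
+        mask_block_cnt=q_mask_cnt, mask_block_idx=q_mask_idx,
+        full_block_cnt=full_q_cnt, full_block_idx=full_q_idx,
+    )
+    return fwd, bwd
+
+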
+@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS_COMPREHENSIVE) +@pytest.mark.parametrize("nheads", [16]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa", "mqa"]) +@pytest.mark.parametrize("headdim", [128]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("use_block_sparsity", [True, False]) +@pytest.mark.parametrize( + "mask_name", + ["block_diagonal", "mini_causal"], +) +@pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112)]) +def test_static_masks( + seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, tile_m, tile_n +): + """Test static masks that don't require recompilation per seqlen pair. + + Known good masks: + - block_diagonal: Masks by 64-element diagonal blocks + - mini_causal: Local causal within 128-element tiles + """ + if COMPUTE_CAPABILITY == 10 and (tile_m, tile_n) != (128, 128): + pytest.skip("TODO: Non-128x128 tiles currently not supported on SM 10.0. due to TMEM") + + _run_mask_test( + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + nheads=nheads, + kv_mode=kv_mode, + headdim=headdim, + dtype=dtype, + mask_name=mask_name, + window_size=None, + window_left=None, + window_right=None, + tile_m=tile_m, + tile_n=tile_n, + use_block_sparsity=use_block_sparsity, + needs_backward=True, + ) + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS_SMOKE) +@pytest.mark.parametrize("nheads", [16]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa"]) +@pytest.mark.parametrize("headdim", [128]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("use_block_sparsity", [True, False]) +@pytest.mark.parametrize( + "mask_name,window_size", + [ + ("causal", None), + ("block_causal", None), + ("sliding_window", 128), + ("sliding_window", 256), + ("sliding_window", 512), + ("document", None), + ], +) +@pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112), (64, 128)]) +def test_parameterized_masks( + seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, window_size, tile_m, tile_n +): + """Test parameterized masks that require recompilation per seqlen pair. + + Uses fewer seqlen combinations to reduce test time. + + Masks tested: + - causal, block_causal: Require offset = seqlen_k - seqlen_q + - sliding_window: Requires window size and offset parameters + - document: Slower to check + """ + if COMPUTE_CAPABILITY == 10 and (tile_m, tile_n) != (128, 128): + pytest.skip("TODO: Non-128x128 tiles currently not supported on SM 10.0. 
due to TMEM") + + _run_mask_test( + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + nheads=nheads, + kv_mode=kv_mode, + headdim=headdim, + dtype=dtype, + mask_name=mask_name, + window_size=window_size, + window_left=None, + window_right=None, + tile_m=tile_m, + tile_n=tile_n, + use_block_sparsity=use_block_sparsity, + needs_backward=True, + ) + + +def test_sm100_block_sparse_sink_all_masked(): + """Block-sparse regression for the sink path""" + if torch.cuda.get_device_capability()[0] != 10: + pytest.skip("SM100-only test") + device = "cuda" + dtype = torch.bfloat16 + batch_size = 1 + seqlen_q = 256 + seqlen_k = 128 + nheads = 8 + headdim = 128 + q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device) + k = torch.randn(batch_size, seqlen_k, nheads, headdim, dtype=dtype, device=device) + v = torch.randn(batch_size, seqlen_k, nheads, headdim, dtype=dtype, device=device) + learnable_sink = torch.full((nheads,), 0.5, dtype=torch.bfloat16, device=device) + zero_cnt = torch.zeros((batch_size, nheads, 1), dtype=torch.int32, device=device) + zero_idx = torch.zeros((batch_size, nheads, 1, 1), dtype=torch.int32, device=device) + sparse = BlockSparseTensorsTorch( + mask_block_cnt=zero_cnt, + mask_block_idx=zero_idx, + full_block_cnt=zero_cnt, + full_block_idx=zero_idx, + ) + softmax_scale = 1.0 / math.sqrt(headdim) + _, lse = _flash_attn_fwd( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + causal=False, + window_size_left=None, + window_size_right=None, + learnable_sink=learnable_sink, + m_block_size=128, + n_block_size=128, + num_threads=384, + pack_gqa=False, + block_sparse_tensors=sparse, + return_lse=True, + ) + # Fully masked tile ⇒ probability mass sits entirely on the sink, so LSE equals sink logit. + expected = learnable_sink.float()[None, :, None].expand_as(lse) + assert torch.allclose(lse, expected, atol=0.0, rtol=0.0) + + +# ============================================================================= +# Backward Helper Functions +# ============================================================================= + +def run_cute_mask_bwd( + q, k, v, out, lse, grad_out, mask_mod_cute, + block_sparse_mask_bwd=None, tile_m=128, tile_n=128, + aux_tensors=None, +): + """Run flash attention backward with mask_mod. + + Args: + q, k, v: Input tensors in BSHD format + out: Forward output tensor + lse: Log-sum-exp from forward pass + grad_out: Gradient of output + mask_mod_cute: CuTE mask modification function + block_sparse_mask_bwd: Block sparse tensors for backward pass + tile_m, tile_n: Tile sizes + aux_tensors: Auxiliary tensors for mask_mod (e.g., doc_ids for document masking) + + Returns (dq, dk, dv) all in BSHD format. + """ + dq, dk, dv = _flash_attn_bwd( + q=q, + k=k, + v=v, + out=out, + dout=grad_out, + lse=lse, + causal=False, + m_block_size=tile_m, + n_block_size=tile_n, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_mask_bwd, + aux_tensors=aux_tensors, + ) + + return dq, dk, dv + + +def run_flex_reference_bwd(q, k, v, block_mask, grad_out, dtype=None): + """Run flex_attention forward + backward for reference. + + Args: + q, k, v: Input tensors in BSHD format + block_mask: Pre-created block mask for flex_attention + grad_out: Gradient of output in BSHD format + dtype: Optional dtype to cast inputs to (e.g., torch.float32 for reference) + + Returns (out, dq, dk, dv) all in BSHD format. 
+ """ + # Transpose to BHSD for flex_attention + if dtype is not None: + q_ref = q.transpose(1, 2).to(dtype).requires_grad_(True) + k_ref = k.transpose(1, 2).to(dtype).requires_grad_(True) + v_ref = v.transpose(1, 2).to(dtype).requires_grad_(True) + grad_out_ref = grad_out.transpose(1, 2).to(dtype) + else: + q_ref = q.transpose(1, 2).requires_grad_(True) + k_ref = k.transpose(1, 2).requires_grad_(True) + v_ref = v.transpose(1, 2).requires_grad_(True) + grad_out_ref = grad_out.transpose(1, 2) + + # Use flex_attention directly without torch.compile for backward tests + # torch.compile can hang on certain mask patterns (e.g., mini_causal with float32) + out_ref = flex_attention(q_ref, k_ref, v_ref, block_mask=block_mask, enable_gqa=True) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), grad_out_ref) + + # Transpose back to BSHD + return ( + out_ref.transpose(1, 2), + dq_ref.transpose(1, 2), + dk_ref.transpose(1, 2), + dv_ref.transpose(1, 2), + ) + + +def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): + if COMPUTE_CAPABILITY != 9: + pytest.skip("SM90-only test") + + batch_size = 1 + seqlen_q = 256 + seqlen_k = 256 + nheads = 4 + nheads_kv = nheads + headdim = 128 + dtype = torch.bfloat16 + tile_m = 80 + tile_n = 128 + + tensors = create_tensors(batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim, dtype) + mask_mod_cute, mask_mod_flex = get_mask_pair("block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k) + bm = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(tile_m, tile_n), + ) + ( + _seq_q, + _seq_k, + _kv_mask_cnt, + _kv_mask_idx, + _full_kv_cnt, + _full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm.as_tuple() + + block_sparse_mask_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, + ) + + softmax_scale = 1.0 / math.sqrt(headdim) + out = torch.empty(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype) + lse = torch.empty(batch_size, nheads, seqlen_q, device="cuda", dtype=torch.float32) + grad_out = torch.randn_like(out) + + with pytest.raises( + ValueError, + match=r"Hint: Backward expects Q-direction block-sparse tensors.*BLOCK_SIZE=\(128, 128\)", + ): + _flash_attn_bwd( + q=tensors["q"], + k=tensors["k"], + v=tensors["v"], + out=out, + dout=grad_out, + lse=lse, + softmax_scale=softmax_scale, + causal=False, + m_block_size=tile_m, + n_block_size=tile_n, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_mask_bwd, + ) + + +def test_gqa_block_sparse_broadcast_pattern_recompilation(): + """Test that different block sparse broadcast patterns trigger recompilation. + + This is a regression test for a bug where: + 1. First call with block_mask H=1 (broadcasts across all query heads) + 2. Second call with block_mask H=nheads (no broadcast) + 3. Second call incorrectly reused cached kernel from first call + + The fix adds block_sparse_broadcast_pattern to the compile key so that + kernels are recompiled when broadcast patterns change. CuTe's + mark_layout_dynamic() keeps stride=0 as static, so different broadcast + patterns require different compiled kernels. 
+ """ + torch.manual_seed(42) + + batch_size = 2 + nheads = 8 + nheads_kv = 2 + seqlen = 257 + headdim = 64 + dtype = torch.bfloat16 + tile_m = 128 + tile_n = 128 + + sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m + + def causal_mask(b, h, q, kv): + return q >= kv + + mask_mod_cute, _ = get_mask_pair("causal", seqlen_q=seqlen, seqlen_k=seqlen) + + tensors = create_tensors(batch_size, seqlen, seqlen, nheads, nheads_kv, headdim, headdim, dtype) + q, k, v = tensors["q"], tensors["k"], tensors["v"] + grad_out = torch.randn_like(tensors["out"]) + softmax_scale = 1.0 / math.sqrt(headdim) + + def run_with_block_mask_nheads(block_mask_nheads: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bm = create_block_mask( + causal_mask, batch_size, block_mask_nheads, seqlen, seqlen, + device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + ( + _seq_q, _seq_k, + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, *_, + ) = bm.as_tuple() + + block_sparse_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, full_block_idx=full_kv_idx, + ) + block_sparse_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, full_block_idx=full_q_idx, + ) + + out = torch.empty_like(tensors["out"]) + lse = torch.empty_like(tensors["lse"]) + + out_tuple = _flash_attn_fwd( + q=q, k=k, v=v, out=out, lse=lse, + softmax_scale=softmax_scale, causal=False, + window_size_left=-1, window_size_right=-1, + m_block_size=tile_m, n_block_size=tile_n, pack_gqa=False, + mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_fwd, + return_lse=True, + ) + out_cute, lse_cute = out_tuple[0], out_tuple[1] + + dq, dk, dv = run_cute_mask_bwd( + q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute, + block_sparse_mask_bwd=block_sparse_bwd, tile_m=tile_m, tile_n=tile_n, + ) + return dq, dk, dv + + flex_block_mask = create_block_mask( + causal_mask, batch_size, nheads, seqlen, seqlen, + device="cuda", BLOCK_SIZE=(tile_m, tile_n), + ) + _, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out, dtype=torch.float32) + dq_ref, dk_ref, dv_ref = dq_ref.to(dtype), dk_ref.to(dtype), dv_ref.to(dtype) + + dq_broadcast, dk_broadcast, dv_broadcast = run_with_block_mask_nheads(1) + dq_no_broadcast, dk_no_broadcast, dv_no_broadcast = run_with_block_mask_nheads(nheads) + + err_broadcast_dq = (dq_broadcast - dq_ref).abs().max().item() + err_no_broadcast_dq = (dq_no_broadcast - dq_ref).abs().max().item() + + print(f"\nGQA block sparse broadcast pattern test:") + print(f" dQ error (H=1 broadcast): {err_broadcast_dq:.2e}") + print(f" dQ error (H={nheads} no broadcast): {err_no_broadcast_dq:.2e}") + + assert err_broadcast_dq < 0.1, f"Broadcast dQ error too large: {err_broadcast_dq:.2e}" + assert err_no_broadcast_dq < 0.1, f"No-broadcast dQ error too large: {err_no_broadcast_dq:.2e}" + + +def test_gqa_expand_stride_zero_bug(): + """Test that GQA with expand()-created K/V tensors works correctly. + + This is a regression test for bugs with expand()-created tensors: + + Forward bug: cute.assume() fails when tensor strides are Python int 0 + (from expand()) instead of MLIR values. + Error: AttributeError: 'int' object has no attribute 'type' + + Backward bug: mark_layout_dynamic fails with expanded tensors. + Error: RuntimeError: Expected strides[leading_dim] == 1, but got N. 
+ + Trigger: expand() + transpose() creates stride=0 dimensions (GQA pattern). + """ + torch.manual_seed(42) + + batch_size = 1 + seqlen = 2048 + headdim = 128 + n_heads = 4 + n_kv_heads = 1 + dtype = torch.bfloat16 + device = "cuda" + + q = torch.randn(batch_size, seqlen, n_heads, headdim, device=device, dtype=dtype) + k_orig = torch.randn(batch_size, seqlen, n_kv_heads, headdim, device=device, dtype=dtype) + v_orig = torch.randn(batch_size, seqlen, n_kv_heads, headdim, device=device, dtype=dtype) + + k = k_orig.expand(batch_size, seqlen, n_heads, headdim) + v = v_orig.expand(batch_size, seqlen, n_heads, headdim) + + assert k.stride()[2] == 0, "K should have stride=0 in head dim from expand()" + assert v.stride()[2] == 0, "V should have stride=0 in head dim from expand()" + + out = torch.empty_like(q) + lse = torch.empty(batch_size, n_heads, seqlen, device=device, dtype=torch.float32) + softmax_scale = 1.0 / math.sqrt(headdim) + + out_tuple = _flash_attn_fwd( + q=q, k=k, v=v, out=out, lse=lse, + softmax_scale=softmax_scale, + causal=True, + m_block_size=128, n_block_size=128, + return_lse=True, + ) + out_fwd, lse_fwd = out_tuple[0], out_tuple[1] + + assert not torch.isnan(out_fwd).any(), "Forward output contains NaN" + assert torch.isfinite(out_fwd).all(), "Forward output contains non-finite values" + + tensors_for_ref = {"q": q, "k": k, "v": v} + tensors_fp32 = {"q": q.float(), "k": k.float(), "v": v.float()} + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + out_ref = compute_reference_flex_attn(tensors_for_ref, causal_mask) + out_ref_fp32 = compute_reference_flex_attn(tensors_fp32, causal_mask) + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + pt_error = (out_ref - out_ref_fp32).abs().max().item() + cute_error = (out_fwd - out_ref_fp32).abs().max().item() + + print(f"\nGQA expand stride=0 test:") + print(f" Forward: kernel err={cute_error:.2e}, ref err={pt_error:.2e}, atol={fwd_atol:.2e}") + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"Forward error {cute_error:.2e} exceeds {rtol}x ref error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + grad_out = torch.randn_like(out_fwd) + dq, dk, dv = _flash_attn_bwd( + q=q, k=k, v=v, out=out_fwd, dout=grad_out, lse=lse_fwd, + softmax_scale=softmax_scale, + causal=True, + m_block_size=128, n_block_size=128, + ) + + assert not torch.isnan(dq).any(), "dQ contains NaN" + assert not torch.isnan(dk).any(), "dK contains NaN" + assert not torch.isnan(dv).any(), "dV contains NaN" + + flex_block_mask = create_block_mask( + causal_mask, batch_size, n_heads, seqlen, seqlen, + device=device, BLOCK_SIZE=(128, 128), + ) + _, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out, dtype=torch.float32) + + bwd_rtol = 2 + bwd_atol_floor = 1e-5 + + dq_atol = max(bwd_atol_floor, 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item()) + dk_atol = max(bwd_atol_floor, 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item()) + dv_atol = max(bwd_atol_floor, 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item()) + + _, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out) + + pt_dq_err = (dq_pt - dq_ref.to(dtype)).abs().max().item() + pt_dk_err = (dk_pt - dk_ref.to(dtype)).abs().max().item() + pt_dv_err = (dv_pt - dv_ref.to(dtype)).abs().max().item() + + cute_dq_err = (dq - dq_ref.to(dtype)).abs().max().item() + cute_dk_err = (dk - dk_ref.to(dtype)).abs().max().item() + cute_dv_err = (dv - dv_ref.to(dtype)).abs().max().item() + + print(f" Backward dQ: kernel 
err={cute_dq_err:.2e}, ref err={pt_dq_err:.2e}, atol={dq_atol:.2e}") + print(f" Backward dK: kernel err={cute_dk_err:.2e}, ref err={pt_dk_err:.2e}, atol={dk_atol:.2e}") + print(f" Backward dV: kernel err={cute_dv_err:.2e}, ref err={pt_dv_err:.2e}, atol={dv_atol:.2e}") + + assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}" + assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}" + assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py new file mode 100644 index 00000000000..11efcc8cdbc --- /dev/null +++ b/tests/cute/test_score_mod.py @@ -0,0 +1,940 @@ +import pytest +import torch +import cutlass +import cutlass.cute as cute +from cutlass._mlir.dialects import math as mlir_math +import operator +from torch.nn.attention.flex_attention import flex_attention +from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd + +COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] + +from score_mod_definitions import ( + # TensorSSA-based score mods + score_mod_identity as score_mod_1, + score_mod_causal as score_mod_2, + score_mod_rel_bias as score_mod_3, + score_mod_rel_bias_x2 as score_mod_4, + score_mod_times_two as score_mod_5, + score_mod_alibi as score_mod_6, + score_mod_sliding_window as score_mod_7, + score_mod_block_diagonal as score_mod_8, + score_mod_causal_v2 as score_mod_9, + score_mod_batch_bias as score_mod_10, + score_mod_dual_buffer as score_mod_11, +) # isort: split +from score_mod_definitions import ( + # Eager (torch) reference score mods + identity_eager, + causal_eager as causal_mask_eager, + rel_bias_eager as relative_bias_eager, + rel_bias_x2_eager as relative_bias_v2_eager, + times_two_eager, + alibi_eager as alibi_bias_eager, + sliding_window_eager, + block_diagonal_eager, + causal_v2_eager as causal_mask_v2_eager, + batch_bias_factory as batch_bias, + dual_buffer_factory as dual_buffer_bias, +) + +COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] + +# Test pairs: (cute_jit_function, eager_reference_function) +TEST_PAIRS = [ + (score_mod_1, None), + (score_mod_2, causal_mask_eager), + (score_mod_3, relative_bias_eager), + (score_mod_4, relative_bias_v2_eager), + (score_mod_5, times_two_eager), + (score_mod_6, alibi_bias_eager), + (score_mod_7, sliding_window_eager), + (score_mod_8, block_diagonal_eager), + (score_mod_9, causal_mask_v2_eager), +] + +# Test pairs with aux_tensors: (cute_jit_function, eager_reference_function_factory) +TEST_PAIRS_WITH_AUX_TENSORS = [ + (score_mod_10, batch_bias), + (score_mod_11, dual_buffer_bias), +] + +SEQLEN_CONFIGS = [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), +] + + +def create_tensors( + batch_size=2, num_heads=4, seqlen_q=64, seqlen_kv=64, dim=128, dtype=torch.bfloat16 +): + q = torch.randn(batch_size, num_heads, seqlen_q, dim, device="cuda", dtype=dtype) + k = torch.randn(batch_size, num_heads, seqlen_kv, dim, device="cuda", dtype=dtype) + v = torch.randn(batch_size, num_heads, seqlen_kv, dim, device="cuda", dtype=dtype) + return q, k, v + + +def run_cute_flash( + q, k, v, cute_score_mod, aux_tensors=None, 
pack_gqa=False +) -> torch.Tensor: + q_transposed, k_transposed, v_transposed = map( + lambda x: x.transpose(1, 2), (q, k, v) + ) + out = torch.empty_like(q_transposed) + _flash_attn_fwd( + q_transposed, + k_transposed, + v_transposed, + return_lse=True, + score_mod=cute_score_mod, + out=out, + lse=None, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + return out.transpose(1, 2) + + +def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: + if dtype is not None: + q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) + return flex_attention( + q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1] + ) + + +@pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 2), (4, 2)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) +def test_cute_vs_flex_attention( + seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair +): + torch.random.manual_seed(42) + cute_score_mod, eager_score_mod = score_mod_pair + + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + q, k, v = create_tensors( + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype + ) + if pack_gqa: + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() + + out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) + + out_pt = run_flex_reference(q, k, v, eager_score_mod) + out_cute = run_cute_flash(q, k, v, cute_score_mod, pack_gqa=pack_gqa) + + # Basic shape and NaN checks + assert out_cute.shape == out_ref_fp32.shape == out_pt.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert not torch.isnan(out_pt).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + assert torch.isfinite(out_pt).all() + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + # Calculate actual errors + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + print(f"\nNumerical comparison for {cute_score_mod.__name__}:") + print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") + print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") + print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") + print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}") + + # Assert that CuTE's error is at most rtol times PyTorch's error + fwd_atol + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +@pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS) +def test_cute_vs_flex_attention_with_aux_tensors( + seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair +): + torch.random.manual_seed(42) + cute_score_mod, eager_score_mod_factory = score_mod_pair + + batch_size = 2 + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + q, k, v = create_tensors( + batch_size=batch_size, + seqlen_q=seqlen_q, + seqlen_kv=seqlen_kv, + num_heads=num_q_heads, + dtype=dtype, 
+ ) + if pack_gqa: + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() + + if cute_score_mod == score_mod_10: + buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [buffer] + eager_score_mod = eager_score_mod_factory(buffer) + assert buffer.shape == (batch_size,) + elif cute_score_mod == score_mod_11: + head_bias = torch.randn(num_q_heads, device="cuda", dtype=dtype) * 0.2 + pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_scale] + eager_score_mod = eager_score_mod_factory(head_bias, pos_scale) + assert head_bias.shape == (num_q_heads,) + assert pos_scale.shape == (seqlen_q,) + + out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) + + out_pt = run_flex_reference(q, k, v, eager_score_mod) + out_cute = run_cute_flash( + q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa + ) + + # Basic shape and NaN checks + assert out_cute.shape == out_ref_fp32.shape == out_pt.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert not torch.isnan(out_pt).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + assert torch.isfinite(out_pt).all() + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + # Calculate actual errors + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + print(f"\nNumerical comparison for {cute_score_mod.__name__}:") + print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") + print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") + print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") + print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}") + + # Assert that CuTE's error is at most rtol times PyTorch's error + fwd_atol + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +def _generate_block_kvcache( + seqlen_k, page_size, batch_size, nheads_k, d, device, dtype +): + import math + from einops import rearrange + + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype + ) + v_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype + ) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache_bshd = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache_bshd = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + k_cache = k_cache_bshd.transpose(1, 2) + v_cache = v_cache_bshd.transpose(1, 2) + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("page_size", [None, 1, 4, 128]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 2), (4, 2)]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_kv", + [ + (1, 128), + (64, 256), + (64, 800), + (256, 256), + (113, 203), + ], +) +@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) +@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="Paged KV cache only supported on SM100") +def test_score_mod_with_paged_kvcache( + seqlen_q, + seqlen_kv, + qhead_per_kvhead, + num_kv_heads, + page_size, + dtype, + score_mod_pair, +): + if COMPUTE_CAPABILITY == 9: + pytest.xfail("Paged KV cache only supported on SM100") + if page_size is not None and seqlen_kv % page_size != 0: + pytest.skip() + + torch.random.manual_seed(42) + cute_score_mod, eager_score_mod = score_mod_pair + + batch_size = 2 + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + dim = 128 + device = "cuda" + + q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype) + + if page_size is None: + k_cache = torch.randn( + batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype + ) + v_cache = torch.randn( + batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype + ) + page_table = None + k_cache_paged = None + v_cache_paged = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype + ) + + cache_seqlens = torch.randint( + 1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device + ) + + from einops import rearrange + + arange = rearrange(torch.arange(seqlen_kv, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + key_padding_mask = arange < cache_seqlens_expanded + + if pack_gqa: + k_cache_rep = k_cache.repeat_interleave(qhead_per_kvhead, dim=1) + v_cache_rep = v_cache.repeat_interleave(qhead_per_kvhead, dim=1) + else: + k_cache_rep = k_cache + v_cache_rep = v_cache + + def make_masked_score_mod(base_score_mod, seqused_k_tensor): + seqused_k_dev = seqused_k_tensor + + def masked_score_mod(score, b, h, q_idx, kv_idx): + if base_score_mod is not None: + score = base_score_mod(score, b, h, q_idx, kv_idx) + seqlen_limit = torch.gather(seqused_k_dev, 0, b.long()) + valid_mask = kv_idx < seqlen_limit + return torch.where(valid_mask, score, torch.full_like(score, float("-inf"))) + + return masked_score_mod + + masked_score_mod_fp32 = make_masked_score_mod(eager_score_mod, cache_seqlens) + masked_score_mod = make_masked_score_mod(eager_score_mod, cache_seqlens) + + out_ref_fp32 = run_flex_reference( + q, k_cache_rep, v_cache_rep, masked_score_mod_fp32, dtype=torch.float32 + ) + out_pt = run_flex_reference(q, k_cache_rep, v_cache_rep, masked_score_mod) + + q_bshd = q.transpose(1, 2) + out_cute = torch.empty_like(q_bshd) + + if page_size is None: + k_bshd = k_cache.transpose(1, 2) + v_bshd = v_cache.transpose(1, 2) + _flash_attn_fwd( + q_bshd, + k_bshd, + v_bshd, + seqused_k=cache_seqlens, + return_lse=True, + score_mod=cute_score_mod, + out=out_cute, + lse=None, + pack_gqa=pack_gqa, + ) + else: + _flash_attn_fwd( + q_bshd, + k_cache_paged, + 
v_cache_paged, + seqused_k=cache_seqlens, + page_table=page_table, + return_lse=True, + score_mod=cute_score_mod, + out=out_cute, + lse=None, + pack_gqa=pack_gqa, + ) + + out_cute = out_cute.transpose(1, 2) + + assert out_cute.shape == out_ref_fp32.shape == out_pt.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert not torch.isnan(out_pt).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + assert torch.isfinite(out_pt).all() + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + print( + f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):" + ) + print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") + print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") + print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") + print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}") + + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("page_size", [None, 128]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_kv", + [ + (64, 128), + (128, 256), + (256, 256), + ], +) +@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS) +@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="Paged KV cache only supported on SM100") +def test_score_mod_with_paged_kvcache_aux_tensors( + seqlen_q, + seqlen_kv, + qhead_per_kvhead, + num_kv_heads, + page_size, + dtype, + score_mod_pair, +): + if COMPUTE_CAPABILITY == 9: + pytest.xfail("Paged KV cache only supported on SM100") + if page_size is not None and seqlen_kv % page_size != 0: + pytest.skip() + + torch.random.manual_seed(42) + cute_score_mod, eager_score_mod_factory = score_mod_pair + + batch_size = 2 + num_q_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + dim = 128 + device = "cuda" + + q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype) + + if page_size is None: + k_cache = torch.randn( + batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype + ) + v_cache = torch.randn( + batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype + ) + page_table = None + k_cache_paged = None + v_cache_paged = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype + ) + + cache_seqlens = torch.randint( + 1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device + ) + + if cute_score_mod == score_mod_10: + buffer = torch.randn(batch_size, device=device, dtype=dtype) * 0.1 + aux_tensors = [buffer] + eager_score_mod = eager_score_mod_factory(buffer) + elif cute_score_mod == score_mod_11: + head_bias = torch.randn(num_q_heads, device=device, dtype=dtype) * 0.2 + pos_scale = torch.arange(seqlen_q, device=device, dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_scale] + eager_score_mod = eager_score_mod_factory(head_bias, pos_scale) + + from einops import rearrange + + arange = rearrange(torch.arange(seqlen_kv, device=device), "s -> 1 s") + cache_seqlens_expanded = 
rearrange(cache_seqlens, "b -> b 1") + key_padding_mask = arange < cache_seqlens_expanded + + if pack_gqa: + k_cache_rep = k_cache.repeat_interleave(qhead_per_kvhead, dim=1) + v_cache_rep = v_cache.repeat_interleave(qhead_per_kvhead, dim=1) + else: + k_cache_rep = k_cache + v_cache_rep = v_cache + + def make_masked_score_mod(base_score_mod, seqused_k_tensor): + seqused_k_dev = seqused_k_tensor + + def masked_score_mod(score, b, h, q_idx, kv_idx): + if base_score_mod is not None: + score = base_score_mod(score, b, h, q_idx, kv_idx) + seqlen_limit = torch.gather(seqused_k_dev, 0, b.long()) + valid_mask = kv_idx < seqlen_limit + return torch.where(valid_mask, score, torch.full_like(score, float("-inf"))) + + return masked_score_mod + + masked_score_mod_fp32 = make_masked_score_mod(eager_score_mod, cache_seqlens) + masked_score_mod = make_masked_score_mod(eager_score_mod, cache_seqlens) + + out_ref_fp32 = run_flex_reference( + q, k_cache_rep, v_cache_rep, masked_score_mod_fp32, dtype=torch.float32 + ) + out_pt = run_flex_reference(q, k_cache_rep, v_cache_rep, masked_score_mod) + + q_bshd = q.transpose(1, 2) + out_cute = torch.empty_like(q_bshd) + + if page_size is None: + k_bshd = k_cache.transpose(1, 2) + v_bshd = v_cache.transpose(1, 2) + _flash_attn_fwd( + q_bshd, + k_bshd, + v_bshd, + seqused_k=cache_seqlens, + return_lse=True, + score_mod=cute_score_mod, + out=out_cute, + lse=None, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + else: + _flash_attn_fwd( + q_bshd, + k_cache_paged, + v_cache_paged, + seqused_k=cache_seqlens, + page_table=page_table, + return_lse=True, + score_mod=cute_score_mod, + out=out_cute, + lse=None, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + + out_cute = out_cute.transpose(1, 2) + + assert out_cute.shape == out_ref_fp32.shape == out_pt.shape + assert not torch.isnan(out_cute).any() + assert not torch.isnan(out_ref_fp32).any() + assert not torch.isnan(out_pt).any() + assert torch.isfinite(out_cute).all() + assert torch.isfinite(out_ref_fp32).all() + assert torch.isfinite(out_pt).all() + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + rtol = 2 + + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + print( + f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):" + ) + print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") + print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") + print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") + print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}") + + assert cute_error <= rtol * pt_error + fwd_atol, ( + f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" + ) + + +@cute.jit +def score_mod_bwd_5(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + """Backward for score_mod_5 (times_two): d(score*2)/d(score) = 2.""" + return grad * cute.full_like(grad, 2.0) + + +@cute.jit +def score_mod_bwd_3(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + """Backward for score_mod_3 (relative_bias): d(score + |q-kv|)/d(score) = 1.""" + return grad + + +@cute.jit +def score_mod_bwd_identity(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + return grad + + +@cute.jit +def score_mod_bwd_causal(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + """Backward for causal masking: d(where(mask, score, -inf))/d(score) = where(mask, 1, 0). 
+ + At unmasked positions (q_idx >= kv_idx), grad passes through. + At masked positions (q_idx < kv_idx), the kernel already zeros grad because P=0. + """ + return grad + + +@cute.jit +def score_mod_squared(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + """Forward: score ** 2.""" + return tSrS_ssa * tSrS_ssa + + +@cute.jit +def score_mod_bwd_squared(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + """Backward for score**2: d(score**2)/d(score) = 2*score.""" + return grad * cute.full_like(grad, 2.0) * score + + +def score_squared_eager(score, b, h, q_idx, kv_idx): + return score * score + + +BWD_TEST_PAIRS = [ + (score_mod_5, score_mod_bwd_5, times_two_eager), + (score_mod_3, score_mod_bwd_3, relative_bias_eager), + (score_mod_squared, score_mod_bwd_squared, score_squared_eager), + (score_mod_2, score_mod_bwd_causal, causal_mask_eager), +] + +BWD_TEST_PAIRS_WITH_AUX = [ + (score_mod_10, score_mod_bwd_identity, batch_bias), + (score_mod_11, score_mod_bwd_identity, dual_buffer_bias), +] + +BWD_TEST_PAIRS_PACK_GQA = [ + (score_mod_5, score_mod_bwd_5, times_two_eager), + (score_mod_3, score_mod_bwd_3, relative_bias_eager), +] + + +def run_cute_flash_bwd( + q, k, v, cute_score_mod, cute_score_mod_bwd, aux_tensors=None, pack_gqa=False +): + """Run flash attention forward + backward with score_mod.""" + q_t = q.transpose(1, 2) + k_t = k.transpose(1, 2) + v_t = v.transpose(1, 2) + + out, lse = _flash_attn_fwd( + q_t, k_t, v_t, + return_lse=True, + score_mod=cute_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + + grad_out = torch.randn_like(out) + + dq, dk, dv = _flash_attn_bwd( + q_t, k_t, v_t, + out, grad_out, lse, + score_mod=cute_score_mod, + score_mod_bwd=cute_score_mod_bwd, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + + return ( + out.transpose(1, 2), + grad_out.transpose(1, 2), + dq.transpose(1, 2), + dk.transpose(1, 2), + dv.transpose(1, 2), + ) + + +def run_flex_reference_bwd(q, k, v, eager_score_mod, grad_out, dtype=None): + """Run flex_attention forward + backward for reference.""" + if dtype is not None: + q = q.to(dtype).requires_grad_(True) + k = k.to(dtype).requires_grad_(True) + v = v.to(dtype).requires_grad_(True) + grad_out = grad_out.to(dtype) + else: + q = q.requires_grad_(True) + k = k.requires_grad_(True) + v = v.requires_grad_(True) + + compiled_flex = torch.compile(flex_attention) + out = compiled_flex( + q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1] + ) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), grad_out) + + return out, dq, dk, dv + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_kv", + [ + (64, 64), + (128, 128), + (256, 256), + (512, 512), + (799, 3), + (3, 799), + (128, 256), + (256, 128), + (113, 203), + ], +) +@pytest.mark.parametrize("dim", [64, 128]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS) +def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_mod_triple): + """Test backward pass with score_mod against flex_attention reference.""" + if COMPUTE_CAPABILITY == 9 and dim == 64: + pytest.skip("head_dim=64 not supported on SM90 for backward") + + torch.random.manual_seed(42) + cute_fwd, cute_bwd, eager_ref = score_mod_triple + + q, k, v = create_tensors( + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=4, dim=dim, dtype=dtype + ) + + out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd( + q, k, v, cute_fwd, cute_bwd + ) + out_ref_fp32, 
dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( + q, k, v, eager_ref, grad_out, dtype=torch.float32 + ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( + q, k, v, eager_ref, grad_out + ) + + assert not torch.isnan(dq_cute).any(), "dQ contains NaN" + assert not torch.isnan(dk_cute).any(), "dK contains NaN" + assert not torch.isnan(dv_cute).any(), "dV contains NaN" + + rtol = 2 + dq_atol = 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item() + dk_atol = 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item() + dv_atol = 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item() + + dq_ref = dq_ref_fp32.to(dtype) + dk_ref = dk_ref_fp32.to(dtype) + dv_ref = dv_ref_fp32.to(dtype) + + pt_dq_err = (dq_pt - dq_ref).abs().max().item() + pt_dk_err = (dk_pt - dk_ref).abs().max().item() + pt_dv_err = (dv_pt - dv_ref).abs().max().item() + + cute_dq_err = (dq_cute - dq_ref).abs().max().item() + cute_dk_err = (dk_cute - dk_ref).abs().max().item() + cute_dv_err = (dv_cute - dv_ref).abs().max().item() + + print(f"\nBackward comparison for {cute_fwd.__name__}:") + print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}") + print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}") + print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}") + + assert cute_dq_err <= rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}" + assert cute_dk_err <= rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}" + assert cute_dv_err <= rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}" + + +def make_aux_tensors_for_bwd(cute_score_mod, eager_factory, seqlen_q, num_heads, batch_size, dtype): + if cute_score_mod == score_mod_10: + buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 + return [buffer], eager_factory(buffer) + head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2 + pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 + return [head_bias, pos_scale], eager_factory(head_bias, pos_scale) + + +@pytest.mark.parametrize( + "seqlen_q,seqlen_kv", + [ + (64, 64), + (128, 128), + (256, 128), + ], +) +@pytest.mark.parametrize("dim", [64, 128]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS_WITH_AUX) +def test_cute_vs_flex_attention_backward_with_aux( + seqlen_q, seqlen_kv, dim, dtype, score_mod_triple +): + if COMPUTE_CAPABILITY == 9 and dim == 64: + pytest.skip("head_dim=64 not supported on SM90 for backward") + + torch.random.manual_seed(42) + cute_fwd, cute_bwd, eager_factory = score_mod_triple + + q, k, v = create_tensors( + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=4, dim=dim, dtype=dtype + ) + + aux_tensors, eager_ref = make_aux_tensors_for_bwd( + cute_fwd, eager_factory, seqlen_q, q.shape[1], q.shape[0], dtype + ) + + out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd( + q, k, v, cute_fwd, cute_bwd, aux_tensors=aux_tensors + ) + out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( + q, k, v, eager_ref, grad_out, dtype=torch.float32 + ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( + q, k, v, eager_ref, grad_out + ) + + assert not torch.isnan(dq_cute).any() + assert not torch.isnan(dk_cute).any() + assert not torch.isnan(dv_cute).any() + + rtol = 3 + dq_atol = 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item() + dk_atol = 2 * (dk_ref_fp32 + 0.3 - 0.3 
- dk_ref_fp32).abs().max().item() + dv_atol = 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item() + + dq_ref = dq_ref_fp32.to(dtype) + dk_ref = dk_ref_fp32.to(dtype) + dv_ref = dv_ref_fp32.to(dtype) + + pt_dq_err = (dq_pt - dq_ref).abs().max().item() + pt_dk_err = (dk_pt - dk_ref).abs().max().item() + pt_dv_err = (dv_pt - dv_ref).abs().max().item() + + cute_dq_err = (dq_cute - dq_ref).abs().max().item() + cute_dk_err = (dk_cute - dk_ref).abs().max().item() + cute_dv_err = (dv_cute - dv_ref).abs().max().item() + + print(f"\nBackward comparison with aux for {cute_fwd.__name__}:") + print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}") + print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}") + print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}") + + assert cute_dq_err <= rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}" + assert cute_dk_err <= rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}" + assert cute_dv_err <= rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}" + + +@pytest.mark.parametrize("seqlen_q,seqlen_kv", [(128, 128), (128, 256)]) +@pytest.mark.parametrize("dim", [64, 128]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)]) +@pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS_PACK_GQA) +def test_cute_vs_flex_attention_backward_pack_gqa( + seqlen_q, seqlen_kv, dim, dtype, qhead_per_kvhead, num_kv_heads, score_mod_triple +): + if COMPUTE_CAPABILITY == 9: + pytest.xfail("pack_gqa backward not yet implemented on SM90") + + torch.random.manual_seed(42) + cute_fwd, cute_bwd, eager_ref = score_mod_triple + + num_q_heads = num_kv_heads * qhead_per_kvhead + q, k, v = create_tensors( + seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dim=dim, dtype=dtype + ) + k = k[:, :num_kv_heads, :, :].clone() + v = v[:, :num_kv_heads, :, :].clone() + + out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd( + q, k, v, cute_fwd, cute_bwd, pack_gqa=True + ) + out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( + q, k, v, eager_ref, grad_out, dtype=torch.float32 + ) + out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( + q, k, v, eager_ref, grad_out + ) + + assert not torch.isnan(dq_cute).any() + assert not torch.isnan(dk_cute).any() + assert not torch.isnan(dv_cute).any() + + rtol = 3 + dq_atol = 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item() + dk_atol = 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item() + dv_atol = 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item() + + dq_ref = dq_ref_fp32.to(dtype) + dk_ref = dk_ref_fp32.to(dtype) + dv_ref = dv_ref_fp32.to(dtype) + + pt_dq_err = (dq_pt - dq_ref).abs().max().item() + pt_dk_err = (dk_pt - dk_ref).abs().max().item() + pt_dv_err = (dv_pt - dv_ref).abs().max().item() + + cute_dq_err = (dq_cute - dq_ref).abs().max().item() + cute_dk_err = (dk_cute - dk_ref).abs().max().item() + cute_dv_err = (dv_cute - dv_ref).abs().max().item() + + print(f"\nBackward Pack-GQA comparison for {cute_fwd.__name__}:") + print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}") + print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}") + print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}") + + assert cute_dq_err <= rtol * pt_dq_err + dq_atol, f"dQ 
error too large: {cute_dq_err:.2e}" + assert cute_dk_err <= rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}" + assert cute_dv_err <= rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/cute/test_score_mod_varlen.py b/tests/cute/test_score_mod_varlen.py new file mode 100644 index 00000000000..7cca7f2aa0a --- /dev/null +++ b/tests/cute/test_score_mod_varlen.py @@ -0,0 +1,1056 @@ +import pytest +import torch +from torch.nn.attention.flex_attention import flex_attention +from flash_attn.cute.interface import _flash_attn_fwd +from test_score_mod import _generate_block_kvcache +from score_mod_definitions import ( + # TensorSSA-based score mods + score_mod_alibi, + score_mod_batch_bias, + score_mod_block_diagonal, + score_mod_causal, + score_mod_causal_v2, + score_mod_debug_global_idx, + score_mod_dual_buffer, + score_mod_global_kv_bias, + score_mod_global_logical_rel_plus_kv_bias, + score_mod_global_q_and_kv_bias, + score_mod_global_q_bias, + score_mod_global_rel_plus_kv_bias, + score_mod_identity, + score_mod_rel_bias, + score_mod_rel_bias_x2, + score_mod_sliding_window, + score_mod_stress_complex_arithmetic, + score_mod_stress_conditional_mask, + score_mod_stress_global_offset, + score_mod_stress_multi_buffer, + score_mod_stress_xor_pattern, + score_mod_times_two, +) # isort: split +from score_mod_definitions import ( + # Eager (torch) reference score mods + identity_eager, + causal_eager, + rel_bias_eager, + rel_bias_x2_eager, + times_two_eager, + alibi_eager, + sliding_window_eager, + block_diagonal_eager, + causal_v2_eager, + batch_bias_factory, + dual_buffer_factory, + packed_kv_bias_factory, + packed_q_bias_factory, + packed_rel_plus_kv_bias_factory, + packed_q_and_kv_bias_factory, + packed_logical_rel_plus_kv_bias_factory, + stress_complex_arithmetic_factory, + stress_conditional_mask_factory, + stress_multi_buffer_factory, + stress_global_offset_factory, + stress_xor_pattern_factory, + debug_global_idx_factory, +) + +IS_SM90 = torch.cuda.get_device_capability()[0] == 9 + +# ============================================================================= +# Test pairs +# ============================================================================= + +# (cute_score_mod, eager_factory_or_fn, aux_type) +# aux_type: None, "batch", "dual_buffer" +# All score_mods use 7-arg signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors) +TEST_PAIRS_NO_GLOBAL = [ + (score_mod_identity, identity_eager, None), + (score_mod_causal, causal_eager, None), + (score_mod_rel_bias, rel_bias_eager, None), + (score_mod_rel_bias_x2, rel_bias_x2_eager, None), + (score_mod_times_two, times_two_eager, None), + (score_mod_alibi, alibi_eager, None), + (score_mod_sliding_window, sliding_window_eager, None), + (score_mod_block_diagonal, block_diagonal_eager, None), + (score_mod_causal_v2, causal_v2_eager, None), + (score_mod_batch_bias, batch_bias_factory, "batch"), + (score_mod_dual_buffer, dual_buffer_factory, "dual_buffer"), +] + +# (cute_score_mod, eager_factory, aux_type, requires_global) +# aux_type: "kv", "q", "q_and_kv", "q_concat", "kv_with_cu", "multi_buffer" +# requires_global: "q" (needs varlen_q), "kv" (needs varlen_k), "both" (needs both) +# All score_mods use 7-arg signature and compute global indices from seqlen_info +TEST_PAIRS_WITH_GLOBAL = [ + (score_mod_global_kv_bias, packed_kv_bias_factory, "kv", "kv"), + (score_mod_global_q_bias, packed_q_bias_factory, "q", "q"), 
+ (score_mod_global_rel_plus_kv_bias, packed_rel_plus_kv_bias_factory, "kv", "kv"), + (score_mod_global_q_and_kv_bias, packed_q_and_kv_bias_factory, "q_and_kv", "both"), + ( + score_mod_global_logical_rel_plus_kv_bias, + packed_logical_rel_plus_kv_bias_factory, + "kv", + "kv", + ), + ( + score_mod_stress_complex_arithmetic, + stress_complex_arithmetic_factory, + "q_concat", + "q", + ), + ( + score_mod_stress_conditional_mask, + stress_conditional_mask_factory, + "kv_with_cu", + "both", + ), + ( + score_mod_stress_multi_buffer, + stress_multi_buffer_factory, + "multi_buffer", + "both", + ), + (score_mod_stress_global_offset, stress_global_offset_factory, "kv", "kv"), + (score_mod_stress_xor_pattern, stress_xor_pattern_factory, "kv_with_cu", "kv"), + (score_mod_debug_global_idx, debug_global_idx_factory, "kv", "kv"), +] + +SEQLEN_CONFIGS = [ + ([1], [1]), + ([1, 1], [1, 1]), + ([2, 3], [2, 3]), + ([8, 16], [8, 16]), + ([32, 32], [32, 32]), + ([64, 128], [64, 128]), + ([64, 56, 128], [64, 56, 128]), + ([256, 512], [256, 512]), + ([113, 203], [113, 203]), + ([239, 1], [239, 1]), + ([64], [64]), + ([128, 128], [128, 128]), + ([32, 32, 32, 32], [32, 32, 32, 32]), + ([16, 32, 64, 128, 256], [16, 32, 64, 128, 256]), + ([1, 1024], [1, 1024]), + ([1024, 1], [1024, 1]), + ([1, 256, 1], [1, 256, 1]), + ([256, 1, 256], [256, 1, 256]), + ([17, 33, 65], [17, 33, 65]), + ([64, 128], [32, 64]), + ([100, 100], [50, 50]), + ([256, 512, 256], [128, 256, 128]), + ([2, 1], [16384, 32 * 1024]), + ([1, 1], [128 * 1024] * 2), + ([2, 1], [8192, 8192]), + ([1, 3], [8192, 8192]), + ([3, 3], [8192, 8192]), + ([128, 128], [8192, 8192]), + ([2, 2, 2], [8 * 1024] * 3), + ([2, 1], [1024 * 32, 16384]), + ([1, 2], [1024 * 32, 16384]), + ([1, 1, 1], [128 * 1024] * 3), + ([1, 1, 1], [256 * 1024] * 3), +] + +# ============================================================================= +# Helper functions +# ============================================================================= + + +def run_cute_flash( + q, + k, + v, + score_mod, + aux_tensors=None, + pack_gqa=False, + cu_seqlens_q=None, + cu_seqlens_k=None, + page_table=None, + seqused_k=None, +): + """Run CuTE flash attention.""" + if cu_seqlens_q is not None or cu_seqlens_k is not None: + out = torch.empty_like(q) + _flash_attn_fwd( + q, + k, + v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, + page_table=page_table, + return_lse=True, + score_mod=score_mod, + out=out, + lse=None, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + return out + + out = torch.empty_like(q) + _flash_attn_fwd( + q, + k, + v, + seqused_k=seqused_k, + page_table=page_table, + return_lse=True, + score_mod=score_mod, + out=out, + lse=None, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + return out + + +def run_flex_varlen_ref(q, k, v, cu_seqlens_q, cu_seqlens_k, score_mod, dtype=None): + """Run flex_attention per-sequence for varlen reference.""" + if cu_seqlens_q is not None: + num_batches = len(cu_seqlens_q) - 1 + else: + num_batches = len(cu_seqlens_k) - 1 + + results = [] + for i in range(num_batches): + # Get Q slice + if cu_seqlens_q is not None: + q_slice = ( + q[cu_seqlens_q[i] : cu_seqlens_q[i + 1]].unsqueeze(0).transpose(1, 2) + ) + else: + q_slice = q[i : i + 1].transpose(1, 2) + + # Get K/V slices + if cu_seqlens_k is not None: + k_slice = ( + k[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze(0).transpose(1, 2) + ) + v_slice = ( + v[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze(0).transpose(1, 2) + ) + else: + k_slice = k[i : i + 
1].transpose(1, 2) + v_slice = v[i : i + 1].transpose(1, 2) + + if dtype is not None: + q_slice, k_slice, v_slice = ( + q_slice.to(dtype), + k_slice.to(dtype), + v_slice.to(dtype), + ) + + def wrapped_mod(score, b, h, q_idx, kv_idx): + return score_mod(score, i, h, q_idx, kv_idx) + + out = flex_attention( + q_slice, + k_slice, + v_slice, + score_mod=wrapped_mod, + enable_gqa=q_slice.shape[1] != k_slice.shape[1], + ) + results.append(out.transpose(1, 2).squeeze(0)) + + return torch.cat(results, dim=0) + + +def setup_tensors(seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype): + """Create Q, K, V tensors and cu_seqlens based on varlen flags.""" + batch_size = len(seqlens_q) + + if varlen_q: + total_q = sum(seqlens_q) + q = torch.randn(total_q, num_heads, head_dim, device="cuda", dtype=dtype) + cu_seqlens_q = torch.tensor( + [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()), + device="cuda", + dtype=torch.int32, + ) + else: + seqlen_q = seqlens_q[0] # All sequences have the same length for non-varlen + q = torch.randn( + batch_size, seqlen_q, num_heads, head_dim, device="cuda", dtype=dtype + ) + cu_seqlens_q = None + + if varlen_k: + total_k = sum(seqlens_k) + k = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype) + v = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype) + cu_seqlens_k = torch.tensor( + [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()), + device="cuda", + dtype=torch.int32, + ) + else: + seqlen_k = seqlens_k[0] # All sequences have the same length for non-varlen + k = torch.randn( + batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype + ) + v = torch.randn( + batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype + ) + cu_seqlens_k = None + + return q, k, v, cu_seqlens_q, cu_seqlens_k + + +def prepare_ref_tensors( + q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q +): + """Prepare tensors for flex_attention reference (handle mixed varlen formats).""" + num_heads = q.shape[1] if varlen_q else q.shape[2] + + if not varlen_q and varlen_k: + seqlen_q = q.shape[1] + q_packed = q.reshape(-1, num_heads, q.shape[-1]) + ref_cu_seqlens_q = torch.tensor( + [seqlen_q * i for i in range(batch_size + 1)], + device="cuda", + dtype=torch.int32, + ) + return q_packed, k, v, ref_cu_seqlens_q, cu_seqlens_k + + if varlen_q and not varlen_k: + return q, k, v, cu_seqlens_q, None + + return q, k, v, cu_seqlens_q, cu_seqlens_k + + +def check_results( + out_cute, + out_ref_fp32, + out_pt, + test_name, + rtol=2, + extra_atol=1e-4, + seqlens_q=None, + cu_seqlens_q=None, +): + """Compare CuTE output against references.""" + assert not torch.isnan(out_cute).any(), f"{test_name}: NaN in output" + assert torch.isfinite(out_cute).all(), f"{test_name}: Inf in output" + + varlen_q = cu_seqlens_q is not None + + if varlen_q: + # Unpack and compare per-sequence + assert seqlens_q is not None, "varlen_q requires use of seqlens_q" + num_seqs = len(seqlens_q) + max_cute_error = 0.0 + max_pt_error = 0.0 + + for i in range(num_seqs): + # Extract sequences using cu_seqlens (all outputs are in packed format) + start_q = cu_seqlens_q[i] + end_q = cu_seqlens_q[i + 1] + cute_seq = out_cute[start_q:end_q] + ref_seq = out_ref_fp32[start_q:end_q] + pt_seq = out_pt[start_q:end_q] + + max_cute_error = max( + max_cute_error, (cute_seq - ref_seq).abs().max().item() + ) + max_pt_error = max(max_pt_error, (pt_seq - ref_seq).abs().max().item()) + + cute_error = max_cute_error + pt_error = max_pt_error + 
else: + # Direct comparison + pt_error = (out_pt - out_ref_fp32).abs().max().item() + cute_error = (out_cute - out_ref_fp32).abs().max().item() + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + + print(f"\n{test_name}:") + print(f" PyTorch vs FP32 ref: {pt_error:.2e}") + print(f" CuTE vs FP32 ref: {cute_error:.2e}") + + tol = rtol * pt_error + fwd_atol + extra_atol + assert cute_error <= tol, ( + f"{test_name}: CuTE error {cute_error:.2e} exceeds tolerance {tol:.2e}" + ) + + +# ============================================================================= +# Tests +# ============================================================================= + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("varlen_q", [True, False]) +@pytest.mark.parametrize("varlen_k", [True, False]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)]) +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_NO_GLOBAL) +def test_varlen_with_score_mod( + seqlens_q, + seqlens_k, + varlen_q, + varlen_k, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_tuple, +): + """Test varlen attention with score_mod functions that don't use global indices. + + Covers: both varlen, varlen Q only, varlen K only. + Skips: neither varlen + """ + if not varlen_q and not varlen_k: + pytest.skip( + "At least one of varlen_q or varlen_k must be True for varlen tests" + ) + + # For non-varlen dimension, all sequences must have same length + if not varlen_q: + seqlens_q = [seqlens_q[0]] * len(seqlens_q) + if not varlen_k: + seqlens_k = [seqlens_k[0]] * len(seqlens_k) + + torch.random.manual_seed(42) + cute_score_mod, eager_factory, aux_type = score_mod_tuple + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = len(seqlens_q) + + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors( + seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype + ) + + if pack_gqa: + if varlen_k: + k = k[:, :num_kv_heads, :].clone() + v = v[:, :num_kv_heads, :].clone() + else: + k = k[:, :, :num_kv_heads, :].clone() + v = v[:, :, :num_kv_heads, :].clone() + + aux_tensors = None + if aux_type == "batch": + bias = torch.zeros(batch_size, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias) + elif aux_type == "dual_buffer": + seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q) + head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2 + pos_bias = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_bias] + eager_score_mod = eager_factory(head_bias, pos_bias) + else: + eager_score_mod = eager_factory + + # Prepare reference tensors + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors( + q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q + ) + + out_ref_fp32 = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32 + ) + out_pt = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype + ) + out_cute = run_cute_flash( + q, + k, + v, + cute_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + + if not varlen_q and varlen_k: + seqlen_q = q.shape[1] + out_ref_fp32 = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim) + out_pt = 
out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim) + + assert out_cute.shape == out_ref_fp32.shape, ( + f"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}" + ) + + test_name = f"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k})" + extra_atol = 2e-3 + check_results( + out_cute, + out_ref_fp32, + out_pt, + test_name, + extra_atol=extra_atol, + seqlens_q=seqlens_q if varlen_q else None, + cu_seqlens_q=cu_seqlens_q if varlen_q else None, + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("varlen_q", [True, False]) +@pytest.mark.parametrize("varlen_k", [True, False]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_WITH_GLOBAL) +def test_varlen_with_global_idx_score_mod( + seqlens_q, + seqlens_k, + varlen_q, + varlen_k, + qhead_per_kvhead, + num_kv_heads, + dtype, + score_mod_tuple, +): + """Test varlen attention with score_mod functions that use global indices. + + These score_mods compute q_idx_global and/or kv_idx_global from seqlen_info for packed tensor indexing. + Skips tests where required global indices aren't available. + """ + if not varlen_q and not varlen_k: + pytest.skip( + "At least one of varlen_q or varlen_k must be True for varlen tests" + ) + + cute_score_mod, eager_factory, aux_type, requires_global = score_mod_tuple + + # Skip if score_mod requires global indices we can't provide + if requires_global == "q" and not varlen_q: + pytest.skip(f"{cute_score_mod.__name__} requires varlen_q for q_idx_global") + if requires_global == "kv" and not varlen_k: + pytest.skip(f"{cute_score_mod.__name__} requires varlen_k for kv_idx_global") + if requires_global == "both" and (not varlen_q or not varlen_k): + pytest.skip(f"{cute_score_mod.__name__} requires both varlen_q and varlen_k") + + # For non-varlen dimension, all sequences must have same length + if not varlen_q: + seqlens_q = [seqlens_q[0]] * len(seqlens_q) + if not varlen_k: + seqlens_k = [seqlens_k[0]] * len(seqlens_k) + + torch.random.manual_seed(42) + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = len(seqlens_q) + max_rel_pos = 512 + + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + + cu_seqlens_q = torch.tensor( + [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()), + device="cuda", + dtype=torch.int32, + ) + cu_seqlens_k = torch.tensor( + [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()), + device="cuda", + dtype=torch.int32, + ) + + if varlen_q: + q = torch.randn(total_q, num_heads, head_dim, device="cuda", dtype=dtype) + else: + seqlen_q = seqlens_q[0] + q = torch.randn( + batch_size, seqlen_q, num_heads, head_dim, device="cuda", dtype=dtype + ) + + if varlen_k: + k = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype) + v = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype) + else: + seqlen_k = seqlens_k[0] + k = torch.randn( + batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype + ) + v = torch.randn( + batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype + ) + + if pack_gqa: + if varlen_k: + k = k[:, :num_kv_heads, :].clone() + v = v[:, :num_kv_heads, :].clone() + else: + k = k[:, :, :num_kv_heads, :].clone() + v = v[:, :, :num_kv_heads, :].clone() + + # Setup aux tensors based on indexing type + if aux_type == "kv": + bias = 
torch.randn(total_k, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias, cu_seqlens_k) + elif aux_type == "q": + bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias, cu_seqlens_q) + elif aux_type == "q_and_kv": + q_bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1 + kv_bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [q_bias, kv_bias] + eager_score_mod = eager_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k) + elif aux_type == "q_concat": + bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias, cu_seqlens_q) + elif aux_type == "kv_with_cu": + kv_bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1 + aux_tensors = [kv_bias] + eager_score_mod = eager_factory(kv_bias, cu_seqlens_q, cu_seqlens_k) + elif aux_type == "multi_buffer": + batch_bias = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 + head_scale = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.1 + 1.0 + q_pos_bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1 + kv_pos_bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1 + rel_pos_scale = ( + torch.randn(max_rel_pos * 2 + 1, device="cuda", dtype=dtype) * 0.1 + ) + aux_tensors = [batch_bias, head_scale, q_pos_bias, kv_pos_bias, rel_pos_scale] + eager_score_mod = eager_factory( + batch_bias, + head_scale, + q_pos_bias, + kv_pos_bias, + rel_pos_scale, + cu_seqlens_q, + cu_seqlens_k, + max_rel_pos, + ) + else: + raise ValueError(f"Unknown aux_type: {aux_type}") + + # Prepare reference tensors for flex_attention + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors( + q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q + ) + + out_ref_fp32 = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32 + ) + out_pt = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype + ) + + kernel_cu_seqlens_q = cu_seqlens_q if varlen_q else None + kernel_cu_seqlens_k = cu_seqlens_k if varlen_k else None + out_cute = run_cute_flash( + q, + k, + v, + cute_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + cu_seqlens_q=kernel_cu_seqlens_q, + cu_seqlens_k=kernel_cu_seqlens_k, + ) + + if varlen_q: + out_ref_final = out_ref_fp32 + out_pt_final = out_pt + out_cute_final = out_cute + else: + seqlen_q = seqlens_q[0] + out_ref_final = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim) + out_pt_final = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim) + out_cute_final = out_cute + + assert out_cute_final.shape == out_ref_final.shape, ( + f"Shape mismatch: {out_cute_final.shape} vs {out_ref_final.shape}" + ) + + test_name = f"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k}, {aux_type})" + + check_results( + out_cute_final, + out_ref_final, + out_pt_final, + test_name, + extra_atol=1e-3, + seqlens_q=seqlens_q if varlen_q else None, + cu_seqlens_q=cu_seqlens_q if varlen_q else None, + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("page_size", [None, 128]) +@pytest.mark.parametrize("varlen_q", [True, False]) +@pytest.mark.parametrize("varlen_k", [True, False]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)]) +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) 
+@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_NO_GLOBAL) +def test_varlen_score_mod_kvcache( + seqlens_q, + seqlens_k, + varlen_q, + varlen_k, + qhead_per_kvhead, + num_kv_heads, + page_size, + dtype, + score_mod_tuple, +): + """Test varlen attention with score_mod and paged KV cache.""" + if IS_SM90 and page_size is not None: + pytest.xfail("paged KV not supported on SM90") + + if not varlen_q and not varlen_k: + pytest.skip( + "At least one of varlen_q or varlen_k must be True for varlen tests" + ) + + if page_size is not None and varlen_k: + pytest.skip("Paged KV requires batched (non-varlen) K") + + if not varlen_q: + seqlens_q = [seqlens_q[0]] * len(seqlens_q) + if not varlen_k: + seqlens_k = [seqlens_k[0]] * len(seqlens_k) + + # Skip if page_size doesn't divide seqlens evenly (for simplicity) + if page_size is not None and not varlen_k: + if seqlens_k[0] % page_size != 0: + pytest.skip("page_size must divide seqlen_k") + + torch.random.manual_seed(42) + cute_score_mod, eager_factory, aux_type = score_mod_tuple + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = len(seqlens_q) + device = "cuda" + + # Setup tensors + q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors( + seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype + ) + + if pack_gqa: + if varlen_k: + k = k[:, :num_kv_heads, :].clone() + v = v[:, :num_kv_heads, :].clone() + else: + k = k[:, :, :num_kv_heads, :].clone() + v = v[:, :, :num_kv_heads, :].clone() + + page_table = None + k_cache_paged = None + v_cache_paged = None + k_cache = k + v_cache = v + + if page_size is not None: + seqlen_k = seqlens_k[0] + ( + k_cache_bhsd, + v_cache_bhsd, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size, num_kv_heads, head_dim, device, dtype + ) + k_cache = k_cache_bhsd.transpose(1, 2) # BHSD -> BSHD + v_cache = v_cache_bhsd.transpose(1, 2) + seqused_k = torch.tensor(seqlens_k, dtype=torch.int32, device=device) + else: + seqused_k = None + + # Setup aux tensors and eager score_mod + aux_tensors = None + if aux_type == "batch": + bias = torch.zeros(batch_size, device=device, dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias) + elif aux_type == "dual_buffer": + seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q) + head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.2 + pos_bias = torch.arange(seqlen_q, device=device, dtype=dtype) * 0.01 + aux_tensors = [head_bias, pos_bias] + eager_score_mod = eager_factory(head_bias, pos_bias) + else: + eager_score_mod = eager_factory + + # Prepare reference tensors + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors( + q, + k_cache, + v_cache, + cu_seqlens_q, + cu_seqlens_k, + varlen_q, + varlen_k, + batch_size, + seqlens_q, + ) + + out_ref_fp32 = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32 + ) + out_pt = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype + ) + + k_input = k_cache_paged if page_size is not None else k_cache + v_input = v_cache_paged if page_size is not None else v_cache + + out_cute = run_cute_flash( + q, + k_input, + v_input, + cute_score_mod, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + cu_seqlens_q=cu_seqlens_q if varlen_q else None, + cu_seqlens_k=cu_seqlens_k if (varlen_k and page_size is None) else None, + page_table=page_table if page_size is not None 
else None, + seqused_k=seqused_k if page_size is not None else None, + ) + + if not varlen_q and varlen_k: + seqlen_q = q.shape[1] + out_ref_fp32 = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim) + out_pt = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim) + + assert out_cute.shape == out_ref_fp32.shape, ( + f"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}" + ) + + test_name = f"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k}, paged={page_size is not None})" + extra_atol = 2e-3 + check_results( + out_cute, + out_ref_fp32, + out_pt, + test_name, + extra_atol=extra_atol, + seqlens_q=seqlens_q if varlen_q else None, + cu_seqlens_q=cu_seqlens_q if varlen_q else None, + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("page_size", [None, 128]) +@pytest.mark.parametrize("varlen_q", [True, False]) +@pytest.mark.parametrize("varlen_k", [True, False]) +@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) +@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS) +@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_WITH_GLOBAL) +def test_varlen_score_mod_with_paged_kvcache_global( + seqlens_q, + seqlens_k, + varlen_q, + varlen_k, + qhead_per_kvhead, + num_kv_heads, + page_size, + dtype, + score_mod_tuple, +): + """Test varlen attention with global idx score_mod and paged KV cache.""" + if IS_SM90 and page_size is not None: + pytest.xfail("paged KV not supported on SM90") + + if page_size is not None and varlen_k: + pytest.skip("Paged KV cache requires batched (non-varlen) K") + + if not varlen_q and not varlen_k: + pytest.skip( + "At least one of varlen_q or varlen_k must be True for varlen tests" + ) + + if not varlen_q: + seqlens_q = [seqlens_q[0]] * len(seqlens_q) + if not varlen_k: + seqlens_k = [seqlens_k[0]] * len(seqlens_k) + + if page_size is not None and not varlen_k: + if seqlens_k[0] % page_size != 0: + pytest.skip("page_size must divide seqlen_k") + + cute_score_mod, eager_factory, aux_type, requires_global = score_mod_tuple + + if requires_global == "q" and not varlen_q: + pytest.skip(f"{cute_score_mod.__name__} requires varlen_q for q_idx_global") + if requires_global == "kv" and not varlen_k: + pytest.skip(f"{cute_score_mod.__name__} requires varlen_k for kv_idx_global") + if requires_global == "both" and (not varlen_q or not varlen_k): + pytest.skip(f"{cute_score_mod.__name__} requires both varlen_q and varlen_k") + + torch.random.manual_seed(42) + + num_heads = num_kv_heads * qhead_per_kvhead + pack_gqa = qhead_per_kvhead > 1 + head_dim = 128 + batch_size = len(seqlens_q) + max_rel_pos = 512 + device = "cuda" + + total_q = sum(seqlens_q) + total_k = sum(seqlens_k) + + cu_seqlens_q = torch.tensor( + [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + cu_seqlens_k = torch.tensor( + [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()), + device=device, + dtype=torch.int32, + ) + cu_seqlens_k_for_kernel = cu_seqlens_k if varlen_k else None + + q = torch.randn(total_q, num_heads, head_dim, device=device, dtype=dtype) + if varlen_k: + k = torch.randn(total_k, num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(total_k, num_heads, head_dim, device=device, dtype=dtype) + else: + seqlen_k = seqlens_k[0] + k = torch.randn( + batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype + ) + + 
if pack_gqa: + if varlen_k: + k = k[:, :num_kv_heads, :].clone() + v = v[:, :num_kv_heads, :].clone() + else: + k = k[:, :, :num_kv_heads, :].clone() + v = v[:, :, :num_kv_heads, :].clone() + + page_table = None + k_cache_paged = None + v_cache_paged = None + k_cache = k + v_cache = v + + if page_size is not None: + seqlen_k = seqlens_k[0] + ( + k_cache_bhsd, + v_cache_bhsd, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size, num_kv_heads, head_dim, device, dtype + ) + k_cache = k_cache_bhsd.transpose(1, 2) # BHSD -> BSHD + v_cache = v_cache_bhsd.transpose(1, 2) + seqused_k = torch.tensor(seqlens_k, dtype=torch.int32, device=device) + else: + seqused_k = None + + if aux_type == "kv": + bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias, cu_seqlens_k) + elif aux_type == "q": + bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias, cu_seqlens_q) + elif aux_type == "q_and_kv": + q_bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1 + kv_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1 + aux_tensors = [q_bias, kv_bias] + eager_score_mod = eager_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k) + elif aux_type == "q_concat": + bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1 + aux_tensors = [bias] + eager_score_mod = eager_factory(bias, cu_seqlens_q) + elif aux_type == "kv_with_cu": + kv_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1 + aux_tensors = [kv_bias] + eager_score_mod = eager_factory(kv_bias, cu_seqlens_q, cu_seqlens_k) + elif aux_type == "multi_buffer": + batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.1 + head_scale = torch.randn(num_heads, device=device, dtype=dtype) * 0.1 + 1.0 + q_pos_bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1 + kv_pos_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1 + rel_pos_scale = ( + torch.randn(max_rel_pos * 2 + 1, device=device, dtype=dtype) * 0.1 + ) + aux_tensors = [batch_bias, head_scale, q_pos_bias, kv_pos_bias, rel_pos_scale] + eager_score_mod = eager_factory( + batch_bias, + head_scale, + q_pos_bias, + kv_pos_bias, + rel_pos_scale, + cu_seqlens_q, + cu_seqlens_k, + max_rel_pos, + ) + else: + raise ValueError(f"Unknown aux_type: {aux_type}") + + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors( + q, + k_cache, + v_cache, + cu_seqlens_q, + cu_seqlens_k, + True, + varlen_k, + batch_size, + seqlens_q, + ) + + out_ref_fp32 = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32 + ) + out_pt = run_flex_varlen_ref( + q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype + ) + + # Run CuTE + k_input = k_cache_paged if page_size is not None else k_cache + v_input = v_cache_paged if page_size is not None else v_cache + + out_cute = torch.empty_like(q) + _flash_attn_fwd( + q, + k_input, + v_input, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k_for_kernel if page_size is None else None, + seqused_k=seqused_k if page_size is not None else None, + page_table=page_table, + return_lse=True, + score_mod=cute_score_mod, + out=out_cute, + lse=None, + aux_tensors=aux_tensors, + pack_gqa=pack_gqa, + ) + + assert out_cute.shape == out_ref_fp32.shape, ( + f"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}" + ) + + test_name = f"{cute_score_mod.__name__} 
(paged={page_size is not None}, {aux_type})" + check_results( + out_cute, + out_ref_fp32, + out_pt, + test_name, + extra_atol=1e-3, + seqlens_q=seqlens_q, + cu_seqlens_q=cu_seqlens_q, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/cute/test_utils.py b/tests/cute/test_utils.py new file mode 100644 index 00000000000..189eb86957d --- /dev/null +++ b/tests/cute/test_utils.py @@ -0,0 +1,213 @@ +"""Unit tests for flash_attn.cute.utils module.""" + +import functools + +from flash_attn.cute import utils as cute_utils +from flash_attn.cute.utils import hash_callable + + +class TestHashCallable: + """Tests for hash_callable function.""" + + def test_returns_cute_hash_when_set_on_function(self): + """hash_callable should return __cute_hash__ immediately when set on function.""" + + def my_func(): + pass + + my_func.__cute_hash__ = "precomputed-hash-123" + + result = hash_callable(my_func) + assert result == "precomputed-hash-123" + + def test_returns_cute_hash_from_wrapped_function(self): + """hash_callable should check __wrapped__ for __cute_hash__.""" + + def inner_func(): + pass + + inner_func.__cute_hash__ = "inner-hash-456" + + # Simulate a decorator that sets __wrapped__ + @functools.wraps(inner_func) + def wrapper_func(): + return inner_func() + + result = hash_callable(wrapper_func) + assert result == "inner-hash-456" + + def test_prefers_wrapper_cute_hash_over_wrapped(self): + """When both wrapper and wrapped have __cute_hash__, prefer wrapper.""" + + def inner_func(): + pass + + inner_func.__cute_hash__ = "inner-hash" + + @functools.wraps(inner_func) + def wrapper_func(): + return inner_func() + + wrapper_func.__cute_hash__ = "wrapper-hash" + + result = hash_callable(wrapper_func) + assert result == "wrapper-hash" + + def test_fallback_to_source_hashing(self): + """hash_callable should fall back to source hashing when no __cute_hash__.""" + + def my_func(): + return 42 + + result = hash_callable(my_func) + # Should return a hex string (SHA256 hash) + assert isinstance(result, str) + assert len(result) == 64 # SHA256 produces 64 hex chars + + def test_same_function_produces_same_hash(self): + """Same function should produce consistent hash.""" + + def my_func(): + return 42 + + hash1 = hash_callable(my_func) + hash2 = hash_callable(my_func) + assert hash1 == hash2 + + def test_different_functions_produce_different_hashes(self): + """Different functions should produce different hashes.""" + + def func_a(): + return 1 + + def func_b(): + return 2 + + hash_a = hash_callable(func_a) + hash_b = hash_callable(func_b) + assert hash_a != hash_b + + def test_fast_path_skips_expensive_hashing(self): + """When __cute_hash__ is set, expensive operations should be skipped.""" + + def my_func(): + pass + + my_func.__cute_hash__ = "fast-hash" + + # Mock at module level since we loaded it directly + original_getsource = cute_utils.inspect.getsource + call_tracker = {"getsource": 0, "sha256": 0} + + def tracking_getsource(*args, **kwargs): + call_tracker["getsource"] += 1 + return original_getsource(*args, **kwargs) + + original_sha256 = cute_utils.hashlib.sha256 + + def tracking_sha256(*args, **kwargs): + call_tracker["sha256"] += 1 + return original_sha256(*args, **kwargs) + + cute_utils.inspect.getsource = tracking_getsource + cute_utils.hashlib.sha256 = tracking_sha256 + try: + result = hash_callable(my_func) + finally: + cute_utils.inspect.getsource = original_getsource + cute_utils.hashlib.sha256 = original_sha256 + + # Neither inspect.getsource nor 
hashlib.sha256 should be called + assert call_tracker["getsource"] == 0, "getsource should not be called" + assert call_tracker["sha256"] == 0, "sha256 should not be called" + assert result == "fast-hash" + + def test_fast_path_on_wrapped_skips_expensive_hashing(self): + """When __cute_hash__ is on __wrapped__, expensive operations should be skipped.""" + + def inner_func(): + pass + + inner_func.__cute_hash__ = "wrapped-fast-hash" + + @functools.wraps(inner_func) + def wrapper_func(): + return inner_func() + + # Mock at module level + original_getsource = cute_utils.inspect.getsource + call_tracker = {"getsource": 0, "sha256": 0} + + def tracking_getsource(*args, **kwargs): + call_tracker["getsource"] += 1 + return original_getsource(*args, **kwargs) + + original_sha256 = cute_utils.hashlib.sha256 + + def tracking_sha256(*args, **kwargs): + call_tracker["sha256"] += 1 + return original_sha256(*args, **kwargs) + + cute_utils.inspect.getsource = tracking_getsource + cute_utils.hashlib.sha256 = tracking_sha256 + try: + result = hash_callable(wrapper_func) + finally: + cute_utils.inspect.getsource = original_getsource + cute_utils.hashlib.sha256 = original_sha256 + + assert call_tracker["getsource"] == 0, "getsource should not be called" + assert call_tracker["sha256"] == 0, "sha256 should not be called" + assert result == "wrapped-fast-hash" + + def test_closure_values_affect_hash(self): + """Functions with different closure values should have different hashes.""" + value1 = 10 + value2 = 20 + + def make_func(val): + def inner(): + return val + + return inner + + func1 = make_func(value1) + func2 = make_func(value2) + + hash1 = hash_callable(func1) + hash2 = hash_callable(func2) + assert hash1 != hash2 + + +class TestHashCallableIntegration: + """Integration tests for hash_callable with flash attention.""" + + def test_repeated_calls_use_cached_hash(self): + """Repeated calls with same score_mod should use cached/fast hash path.""" + + def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors): + return tSrS_ssa + + # Set __cute_hash__ to simulate Inductor-generated code + score_mod.__cute_hash__ = "inductor-generated-hash" + + original_getsource = cute_utils.inspect.getsource + call_count = [0] # Use list for mutable counter in nested function + + def counting_getsource(*args, **kwargs): + call_count[0] += 1 + return original_getsource(*args, **kwargs) + + cute_utils.inspect.getsource = counting_getsource + try: + # Call hash_callable multiple times + hash1 = hash_callable(score_mod) + hash2 = hash_callable(score_mod) + hash3 = hash_callable(score_mod) + finally: + cute_utils.inspect.getsource = original_getsource + + # getsource should never be called because __cute_hash__ is set + assert call_count[0] == 0, f"getsource was called {call_count[0]} times" + assert hash1 == hash2 == hash3 == "inductor-generated-hash" + diff --git a/tests/ops/triton/test_layer_norm.py b/tests/ops/triton/test_layer_norm.py index 3d92b6b3296..a91c2b09f1f 100644 --- a/tests/ops/triton/test_layer_norm.py +++ b/tests/ops/triton/test_layer_norm.py @@ -16,6 +16,8 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8 +# @pytest.mark.parametrize("zero_centered_weight", [False, True]) +@pytest.mark.parametrize("zero_centered_weight", [False]) @pytest.mark.parametrize("has_weight1", [False, True]) # @pytest.mark.parametrize("has_weight1", [True]) @pytest.mark.parametrize("has_x1", [False, True]) @@ -54,6 +56,7 @@ def test_layer_norm( has_rowscale, has_x1, has_weight1, + zero_centered_weight, ): if 
has_rowscale and has_x1: pytest.skip("Not supported") @@ -145,6 +148,7 @@ def test_layer_norm( rowscale=rowscale, prenorm=prenorm, residual_in_fp32=residual_in_fp32, + zero_centered_weight=zero_centered_weight, is_rms_norm=is_rms_norm, return_dropout_mask=True, ) @@ -162,6 +166,7 @@ def test_layer_norm( dropout_p=dropout_p, rowscale=rowscale, prenorm=prenorm, + zero_centered_weight=zero_centered_weight, dropout_mask=dropout_mask, dropout_mask1=dropout_mask1, ) @@ -177,6 +182,7 @@ def test_layer_norm( dropout_p=dropout_p, rowscale=rowscale, prenorm=prenorm, + zero_centered_weight=zero_centered_weight, dropout_mask=dropout_mask, dropout_mask1=dropout_mask1, upcast=True, diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py index 503b7bf01c3..d5590fcfc82 100644 --- a/tests/test_flash_attn_ck.py +++ b/tests/test_flash_attn_ck.py @@ -1399,7 +1399,7 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() - assert (q.grad - q_ref.grad).abs().max().item() <= 5 * ( + assert (q.grad - q_ref.grad).abs().max().item() <= 7 * ( q_pt.grad - q_ref.grad ).abs().max().item() + 1e-3 assert (k.grad - k_ref.grad).abs().max().item() <= 5 * ( diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py old mode 100644 new mode 100755 index d64246f9505..a37326ee5d8 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -1,6 +1,4 @@ import math -import os -import random import pytest import torch @@ -18,12 +16,7 @@ from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import _get_block_size_n from flash_attn.layers.rotary import apply_rotary_emb -from flash_attn.flash_attn_triton_amd.utils import DEBUG - -# Test ROCM Triton Backend -USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" -if USE_TRITON_ROCM: - random.seed(42) +from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_rdna MAX_HEADDIM_SM8x = 192 @@ -572,33 +565,26 @@ def get_dropout_fraction( return dropped.sum() / valid.sum() -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -@pytest.mark.parametrize("dtype", [torch.float16]) -# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("alibi", [False, True]) -@pytest.mark.parametrize("alibi", [False]) -# @pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [False]) @pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128]) # @pytest.mark.parametrize("d", [32]) # @pytest.mark.parametrize('seqlen', [128, 
256, 384, 512, 768, 1024, 2048]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize("seqlen", [128]) -# @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -@pytest.mark.parametrize("dropout_p", [0.0]) +# @pytest.mark.parametrize("seqlen", [512]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize("dropout_p", [0.0]) def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype): - if USE_TRITON_ROCM: - if dropout_p != 0.0: - pytest.skip("Dropout not supported in AMD's Triton Backend yet") - - if local == True: - pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") - if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = "cuda" @@ -719,45 +705,35 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() if dropout_p > 0.0: - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) - if DEBUG: - print("dqkv:", dqkv, dqkv.shape) - print("dqkv_ref:", dqkv_ref, dqkv_ref.shape) - print("dqkv_pt:", dqkv_pt, dqkv_pt.shape) if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -@pytest.mark.parametrize('dtype', [torch.float16]) -# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("alibi", [False, True]) -@pytest.mark.parametrize("alibi", [False]) -# @pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [True]) @pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize('causal', [False]) +# @pytest.mark.parametrize("local", [True]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048]) # @pytest.mark.parametrize('seqlen', [128]) -# @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -@pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize('dropout_p', [0.0]) def test_flash_attn_varlen_qkvpacked( seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype ): - if USE_TRITON_ROCM: - if dropout_p != 0.0: - pytest.skip("Dropout not supported in AMD's Triton Backend yet") - - if local == True: - pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") if 
seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = "cuda" @@ -877,7 +853,7 @@ def test_flash_attn_varlen_qkvpacked( assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() if dropout_p > 0.0: - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) @@ -888,21 +864,19 @@ def test_flash_attn_varlen_qkvpacked( # @pytest.mark.parametrize("kvpacked", [True, False]) @pytest.mark.parametrize("kvpacked", [False]) -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("kvpacked", [False]) +@pytest.mark.parametrize("dtype", ([torch.float16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("dtype", [torch.float16]) -# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("deterministic", [False, True]) -# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("alibi", [False, True]) -@pytest.mark.parametrize("alibi", [False]) -# @pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [False]) @pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) -@pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) @@ -925,22 +899,16 @@ def test_flash_attn_varlen_qkvpacked( ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) -# @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -@pytest.mark.parametrize("dropout_p", [0.0]) -# @pytest.mark.parametrize("softcap", [0.0, 50.0]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize("dropout_p", [0.0]) @pytest.mark.parametrize("softcap", [0.0]) def test_flash_attn_output( seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap ): if USE_TRITON_ROCM: - if dropout_p != 0.0: - pytest.skip("Dropout not supported on AMD's Triton Backend yet") - - if softcap != 0.0: - pytest.skip("softcap not supported on AMD's Triton Backend yet") - - if local == True: - pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + if causal: + if seqlen_q ==1024 and seqlen_k==1024 and d==160: + pytest.skip("This test with causal=True is flakey") if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1002,10 +970,6 @@ def test_flash_attn_output( deterministic=deterministic, return_attn_probs=True, ) - if DEBUG: 
- print("out:", out, out.shape) - print("lse:", lse, lse.shape) - if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( S_dmask, @@ -1160,52 +1124,34 @@ def test_flash_attn_output( # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. - if DEBUG: - print("out:", out, out.shape) - print("out_ref:", out_ref, out_ref.shape) assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() if dropout_p > 0.0: - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): - if DEBUG: - print("dv:", dv, dv.shape) - print("dv_ref:", dv_ref, dv_ref.shape) - print("dv_pt:", dv_pt, dv_pt.shape) - assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() - - if DEBUG: - print("dk:", dk, dk.shape) - print("dk_ref:", dk_ref, dk_ref.shape) - print("dk_pt:", dk_pt, dk_pt.shape) - assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() - - if DEBUG: - print("dq:", dq, dq.shape) - print("dq_ref:", dq_ref, dq_ref.shape) - print("dq_pt:", dq_pt, dq_pt.shape) assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() - + assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() @pytest.mark.parametrize("kvpacked", [False]) # @pytest.mark.parametrize('kvpacked', [False]) -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -@pytest.mark.parametrize('dtype', [torch.float16]) -# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize('mha_type', ["mha"]) -# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize('mha_type', ["mqa"]) @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("alibi", [False, True]) -@pytest.mark.parametrize("alibi", [False]) -# @pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [True]) @pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize('causal', [False]) +# @pytest.mark.parametrize("local", [True]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [True]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [160]) @@ -1226,23 +1172,14 @@ def test_flash_attn_output( ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) -# @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -@pytest.mark.parametrize('dropout_p', [0.0]) -# @pytest.mark.parametrize("softcap", [0.0, 50.0]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("softcap", 
[0.0]) def test_flash_attn_varlen_output( seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap ): if USE_TRITON_ROCM: - if dropout_p != 0.0: - pytest.skip("Dropout not supported in AMD's Triton Backend yet") - - if local == True: - pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") - - if softcap != 0.0: - pytest.skip("softcap not supported on AMD's Triton Backend yet") - + if seqlen_q == 1 and seqlen_k == 147 and kvpacked == True and dropout_p != 0.0: + pytest.skip("This config with dropout is flaky on AMD.") if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1347,11 +1284,6 @@ def test_flash_attn_varlen_output( deterministic=deterministic, return_attn_probs=True, ) - if DEBUG: - print("out_unpad:", out_unpad, out_unpad.shape) - print("sm_lse:", sm_lse, sm_lse.shape) - - out = output_pad_fn(out_unpad) if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( @@ -1516,34 +1448,19 @@ assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() if dropout_p > 0.0: - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04) if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): - if DEBUG: - print("dv:", dv, dv.shape) - print("dv_ref:", dv_ref, dv_ref.shape) - print("dv_pt:", dv_pt, dv_pt.shape) - assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() - - if DEBUG: - print("dk:", dk, dk.shape) - print("dk_ref:", dk_ref, dk_ref.shape) - print("dk_pt:", dk_pt, dk_pt.shape) - assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() - - if DEBUG: - print("dq:", dq, dq.shape) - print("dq_ref:", dq_ref, dq_ref.shape) - print("dq_pt:", dq_pt, dq_pt.shape) assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() + assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -@pytest.mark.parametrize("dtype", [torch.float16]) -# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -1571,6 +1488,10 @@ def test_flash_attn_varlen_output( ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): + if USE_TRITON_ROCM: + if is_rdna(): + if seqlen_q == 1 and seqlen_k == 239 and d == 256: + pytest.skip("This config does not work on RDNA devices.") if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1646,27 +1567,13 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): # of a Pytorch implementation. 
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 - if DEBUG: - print("dv:", dv, dv.shape) - print("dv_ref:", dv_ref, dv_ref.shape) - print("dv_pt:", dv_pt, dv_pt.shape) - assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 - - if DEBUG: - print("dk:", dk, dk.shape) - print("dk_ref:", dk_ref, dk_ref.shape) - print("dk_pt:", dk_pt, dk_pt.shape) + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 - if DEBUG: - print("dq:", dq, dq.shape) - print("dq_ref:", dq_ref, dq_ref.shape) - print("dq_pt:", dq_pt, dq_pt.shape) - assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -@pytest.mark.parametrize("dtype", [torch.float16]) -# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -1834,6 +1741,136 @@ def test_flash_attn_varlen_causal( assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [True]) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("swap_sq_sk", [False]) +# @pytest.mark.parametrize("swap_sq_sk", [False]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (3, 1024), + (1, 339), + (64, 800), + (3, 799), + (64, 2048), + (16, 20000), + (16, 100000), + (128, 128), + (256, 256), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.skip() +def test_flash_attn_splitkv( + seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype +): + if USE_TRITON_ROCM: + if seqlen_q == 1 and seqlen_k == 339 and swap_sq_sk == True: + pytest.skip("This config is flaky on AMD.") + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 1 + nheads = 12 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, 
+    if alibi:
+        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
+    else:
+        alibi_slopes, attn_bias = None, None
+    out, lse, _ = flash_attn_func(
+        q,
+        k,
+        v,
+        0.0,
+        causal=causal,
+        window_size=window_size,
+        alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
+        return_attn_probs=True,
+    )
+    out_ref, attn_ref = attention_ref(
+        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size
+    )
+    out_pt, attn_pt = attention_ref(
+        q,
+        k,
+        v,
+        None,
+        None,
+        attn_bias,
+        0.0,
+        None,
+        causal=causal,
+        window_size=window_size,
+        upcast=False,
+        reorder_ops=True,
+    )
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    g = torch.randn_like(out)
+    do_o = (g.float() * out.float()).sum(-1)
+    (
+        dq,
+        dk,
+        dv,
+    ) = torch.autograd.grad(out, (q, k, v), g)
+    (
+        dq_ref,
+        dk_ref,
+        dv_ref,
+    ) = torch.autograd.grad(out_ref, (q, k, v), g)
+    (
+        dq_pt,
+        dk_pt,
+        dv_pt,
+    ) = torch.autograd.grad(out_pt, (q, k, v), g)
+    print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+    print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+    print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+    print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+    print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+    print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+    print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+    print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+    print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+    print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+    print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+    print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
+
+    mult = 2 if not alibi else 8
+    assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
+    assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
+    assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
+
+
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("num_splits", [1, 0])
@@ -1850,12 +1887,10 @@ def test_flash_attn_varlen_causal(
 # @pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
 # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
-# @pytest.mark.parametrize("rotary_interleaved", [False, True])
-@pytest.mark.parametrize("rotary_interleaved", [False])
-# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
-@pytest.mark.parametrize("rotary_fraction", [0.0])
-# @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
-# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
+@pytest.mark.parametrize("rotary_interleaved", [False, True])
+# @pytest.mark.parametrize("rotary_interleaved", [False])
+@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
+# @pytest.mark.parametrize("rotary_fraction", [0.0])
 @pytest.mark.parametrize("paged_kv_block_size", [None])
 # @pytest.mark.parametrize("has_leftpad", [False, True])
 @pytest.mark.parametrize("has_leftpad", [False])
@@ -1901,18 +1936,6 @@ def test_flash_attn_kvcache(
     num_splits,
     dtype,
 ):
-    if USE_TRITON_ROCM:
-        if paged_kv_block_size is not None:
-            pytest.skip("paged attention not supported on AMD's Triton Backend yet")
-
-        if local == True:
-            pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")
-
-        if rotary_interleaved == True or rotary_fraction > 0.0:
-            pytest.skip("rotary embedding not supported on AMD's Triton Backend yet")
-
-        if has_leftpad == True:
-            pytest.skip("cache_leftpad not supported on AMD's Triton Backend yet")
     if seqlen_q > seqlen_k and new_kv:
         pytest.skip()
     if not new_kv and rotary_fraction > 0.0:
@@ -2157,3 +2180,366 @@ def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k,
     )[:, :seqlen_k]
     return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
+
+# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize('causal', [True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [128])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 239),
+        (239, 1),
+        (3, 799),
+        (799, 3),
+        (1024, 128),
+        (97, 97),
+        (128, 128),
+        (200, 200),
+        (256, 256),
+        (257, 257),
+        (384, 384),
+        (512, 512),
+        (768, 768),
+        (1024, 1024),
+    ],
+)
+@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
+# @pytest.mark.parametrize("dropout_p", [0.0])
+@pytest.mark.skip()
+def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
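+    # Re-run the same seeded forward/backward many times; outputs and gradients must match
+    # the first run (dq up to a tiny arithmetic tolerance), so any race-induced
+    # nondeterminism shows up as a failure.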
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
+    nheads = 4
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    torch.random.manual_seed(42)
+    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
+    g = torch.randn_like(out0)
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
+        (
+            dq0,
+            dk0,
+            dv0,
+        ) = torch.autograd.grad(out0, (q, k, v), g)
+        # Numerical error if we just do any arithmetic on dq
+        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()
+
+    for i in range(250):
+        torch.random.manual_seed(42)
+        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
+        assert torch.equal(out, out0)
+        assert torch.equal(lse, lse0)
+
+        if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
+            (
+                dq,
+                dk,
+                dv,
+            ) = torch.autograd.grad(out, (q, k, v), g)
+            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
+            if not dq_equal:
+                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
+            assert torch.equal(dv, dv0)
+            assert torch.equal(dk, dk0)
+            assert dq_equal
+
+
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize("d", [16, 32, 64])
+# @pytest.mark.parametrize('d', [16])
+@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
+# @pytest.mark.parametrize('seqlen', [2])
+@pytest.mark.skip()
+def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
+    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
+    in the case where seqlen % 128 != 0.
+ """ + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 2 + nheads = 5 + q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5 + k, v = [ + torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3 + for _ in range(2) + ] + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + out = flash_attn_func(q, k, v, causal=causal) + g = torch.randn_like(out) + out.backward(g) + q_pt = q.detach().clone().requires_grad_(True) + k_pt = k.detach().clone().requires_grad_(True) + v_pt = v.detach().clone().requires_grad_(True) + out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True) + out_pt.backward(g) + q_ref = q.detach().clone().requires_grad_(True) + k_ref = k.detach().clone().requires_grad_(True) + v_ref = v.detach().clone().requires_grad_(True) + out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) + out_ref.backward(g) + print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") + print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") + print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}") + print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}") + print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") + print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + assert (q.grad - q_ref.grad).abs().max().item() <= 5 * ( + q_pt.grad - q_ref.grad + ).abs().max().item() + 1e-3 + assert (k.grad - k_ref.grad).abs().max().item() <= 5 * ( + k_pt.grad - k_ref.grad + ).abs().max().item() + 1e-3 + assert (v.grad - v_ref.grad).abs().max().item() <= 5 * ( + v_pt.grad - v_ref.grad + ).abs().max().item() + 1e-3 + + +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize('dtype', [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize('d', [64]) +@pytest.mark.parametrize("seqlen", [97, 128, 200, 256]) +# @pytest.mark.parametrize('seqlen', [128]) +@pytest.mark.skip() +def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): + """We previously had a bug where we were using the wrong strides of dout, which shows up + when dout is not contiguous. + """ + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + nheads = 2 + q, k, v = [ + torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True) + for _ in range(3) + ] + out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...") + # So g is not contiguous + g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2] + out.backward(g) + q_pt = q.detach().clone().requires_grad_(True) + k_pt = k.detach().clone().requires_grad_(True) + v_pt = v.detach().clone().requires_grad_(True) + out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True) + out_pt = rearrange(out_pt, "b s ... -> s b ...") + out_pt.backward(g) + q_ref = q.detach().clone().requires_grad_(True) + k_ref = k.detach().clone().requires_grad_(True) + v_ref = v.detach().clone().requires_grad_(True) + out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) + out_ref = rearrange(out_ref, "b s ... 
-> s b ...") + out_ref.backward(g) + print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") + print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") + print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}") + print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}") + print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") + print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + assert (q.grad - q_ref.grad).abs().max().item() <= 2 * ( + q_pt.grad - q_ref.grad + ).abs().max().item() + assert (k.grad - k_ref.grad).abs().max().item() <= 2 * ( + k_pt.grad - k_ref.grad + ).abs().max().item() + assert (v.grad - v_ref.grad).abs().max().item() <= 2 * ( + v_pt.grad - v_ref.grad + ).abs().max().item() + + +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize("d", [16, 32, 64]) +# @pytest.mark.parametrize('d', [16]) +@pytest.mark.skip() +def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): + """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, + in the case where seqlen % 128 != 0 or varlen. + """ + device = "cuda" + # set seed + torch.random.manual_seed(0) + nheads = 5 + q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32) + k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32) + Mq = 256 + Mk = 3 + + q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3 + k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)] + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + + out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal) + g = torch.randn_like(out) + out.backward(g) + + assert not q.grad.isnan().any() + assert not k.grad.isnan().any() + assert not v.grad.isnan().any() + + +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [True]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("swap_sq_sk", [False]) +# @pytest.mark.parametrize("swap_sq_sk", [False]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (3, 799), + (127, 512), + (127, 513), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1023, 1024), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.skip() +def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype): + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 4 + nheads = 9 + window_size 
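+    # With deterministic=True, repeated backward passes over the same graph must return
+    # identical gradients; the loop below replays torch.autograd.grad 50 times and compares.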
+    if (
+        max(seqlen_q, seqlen_k) >= 2048
+        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
+    ):
+        pytest.skip()  # Reference implementation OOM
+    if swap_sq_sk:
+        seqlen_q, seqlen_k = seqlen_k, seqlen_q
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 4
+    nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
+
+    g = torch.randn_like(out)
+    dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
+    for _ in range(50):
+        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
+        assert torch.equal(dv, dv0)
+        assert torch.equal(dk, dk0)
+        assert torch.equal(dq, dq0)
+
+
+@pytest.mark.parametrize("dtype", ([torch.float16]))
+# @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("local", [False])
+# @pytest.mark.parametrize("local", [True])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [56, 80])
+# @pytest.mark.parametrize("d", [64])
+@pytest.mark.parametrize("swap_sq_sk", [False])
+# @pytest.mark.parametrize("swap_sq_sk", [True])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 239),
+        (3, 799),
+        (127, 512),
+        (127, 513),
+        (113, 203),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (1023, 1024),
+    ],
+)
+# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
+@pytest.mark.skip()
+def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
+    if (
+        max(seqlen_q, seqlen_k) >= 2048
+        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
+    ):
+        pytest.skip()  # Reference implementation OOM
+    if swap_sq_sk:
+        seqlen_q, seqlen_k = seqlen_k, seqlen_q
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 2
+    nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    (
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        q,
+        k,
+        v,
+        output_pad_fn,
+        dq_pad_fn,
+        dk_pad_fn,
+    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
+    out = flash_attn_varlen_func(
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        0.0,
+        causal=causal,
+        window_size=window_size,
+        deterministic=True,
+    )
+
+    g = torch.randn_like(out)
+    dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+    for _ in range(50):
+        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+        assert torch.equal(dv, dv0)
+        assert torch.equal(dk, dk0)
+        assert torch.equal(dq, dq0)
diff --git a/tests/test_rotary.py b/tests/test_rotary.py
index 0676d329c6f..1b69cff224d 100644
--- a/tests/test_rotary.py
+++ b/tests/test_rotary.py
@@ -5,6 +5,9 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange
+
+import triton
+
 from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
 from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
 from flash_attn.bert_padding import pad_input, unpad_input
@@ -97,6 +100,8 @@ def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_t
     "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
 )
 # @pytest.mark.parametrize('dtype', ([torch.float16]))
+@pytest.mark.parametrize("compiled", [False, True])
+# @pytest.mark.parametrize("compiled", [True])
 @pytest.mark.parametrize("gqa", [False, True])
 # @pytest.mark.parametrize("gqa", [False])
 @pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
@@ -105,7 +110,9 @@ def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_t
 # @pytest.mark.parametrize('rotary_fraction', [1.0])
 @pytest.mark.parametrize("interleaved", [False, True])
 # @pytest.mark.parametrize('interleaved', [False])
-def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, gqa, dtype):
+def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, gqa, compiled, dtype):
+    if compiled:  # Don't fall back to eager just because of recompilation
+        torch._dynamo.config.recompile_limit = 2 ** 31
     rtol = 1e-3
     batch_size = 32
     nheads = 4
@@ -126,7 +133,8 @@ def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, gqa,
     qkv_pt = qkv.detach().clone().requires_grad_()
     cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
     seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
-    out = apply_rotary_emb_qkv_(
+    fn = apply_rotary_emb_qkv_ if not compiled else torch.compile(apply_rotary_emb_qkv_)
+    out = fn(
         qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved,
         num_heads_q=None if not gqa else nheads
     )
@@ -271,7 +279,7 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of


 def test_compilation_count():
-    batch_size = 1
+    nheads = 4
     headdim = 128
     device = "cuda"
     dtype = torch.float16
@@ -288,11 +296,17 @@ def count_compilations(*args, **kwargs):
     old_cache_func = JITFunction.cache_hook
     try:
-        rotary_kernel.cache.clear()
+        if hasattr(rotary_kernel, "cache"):
+            rotary_kernel.cache.clear()
+        else:  # Triton 3.3 replaces cache with per-device device_caches
+            device = triton.runtime.driver.active.get_current_device()
+            # device_caches[device] returns a 4-tuple: (kernel_cache, target, backend, binder)
+            rotary_kernel.device_caches[device][0].clear()
+
         JITFunction.cache_hook = count_compilations
         for seqlen in (128, 256):
-            for nheads in (4, 32):
+            for batch_size in (4, 32):
                 x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
                 x.requires_grad_()
                 cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
diff --git a/usage.md b/usage.md
index 133bfbdb6b2..6cd23652415 100644
--- a/usage.md
+++ b/usage.md
@@ -1,8 +1,7 @@
 # FlashAttention adoption

 We've been very happy to see FlashAttention being adopted by many organizations
-and research labs to speed up their training / inference (within 6 months after
-FlashAttention's release, at the time of writing).
+and research labs to speed up their training / inference.
 This page contains a partial list of places where FlashAttention is being used.
 If you'd like to add links to your organization / product / codebase, please open a
 PR or email us. We'd very much like to hear from you!
diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py
index a61d0f9613a..4654d74779c 100644
--- a/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm_flash_attn/flash_attn_interface.py
@@ -103,6 +103,7 @@ def get_scheduler_metadata(
     max_seqlen_k_new=0,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    attention_chunk=0,
     has_softcap=False,
     num_splits=0,  # Can be tuned for speed
     pack_gqa=None,  # Can be tuned for speed
@@ -124,6 +125,7 @@ def get_scheduler_metadata(
         max_seqlen_k_new,
         causal,
         window_size[0], window_size[1],
+        attention_chunk,
         has_softcap,
         num_splits,
         pack_gqa,
@@ -147,6 +149,7 @@ def flash_attn_varlen_func(
     softmax_scale=None,
     causal=False,
     window_size: Optional[List[int]] = None,
+    attention_chunk=0,
     softcap=0.0,  # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
@@ -204,6 +207,7 @@ def flash_attn_varlen_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        attention_chunk: int. If > 0, chunked attention size for FA3.
         softcap: float. Anything > 0 activates softcapping attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
@@ -248,6 +252,8 @@ def flash_attn_varlen_func(
         )
         if s_aux is not None:
             raise NotImplementedError("FA2 does not support s_aux")
+        if attention_chunk:
+            raise NotImplementedError("FA2 does not support attention_chunk")
         if num_splits > 1:
             raise NotImplementedError("FA2 does not support num_splits > 1")
         out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
@@ -294,6 +300,7 @@ def flash_attn_varlen_func(
             softmax_scale,
             causal,
             real_window_size[0], real_window_size[1],
+            attention_chunk,
             softcap,
             True,  # rotary_interleaved
             scheduler_metadata,
diff --git a/vllm_flash_attn/pyproject.toml b/vllm_flash_attn/pyproject.toml
deleted file mode 100644
index 3201555763e..00000000000
--- a/vllm_flash_attn/pyproject.toml
+++ /dev/null
@@ -1,3 +0,0 @@
-[tool.black]
-line-length = 100
-target-version = ['py38']
\ No newline at end of file