diff --git a/.github/workflows/generate-skip-doc-change.py b/.github.upstream/generate-skip-doc-change.py similarity index 100% rename from .github/workflows/generate-skip-doc-change.py rename to .github.upstream/generate-skip-doc-change.py diff --git a/.github/workflows/skip-doc-change.yml.j2 b/.github.upstream/skip-doc-change.yml.j2 similarity index 100% rename from .github/workflows/skip-doc-change.yml.j2 rename to .github.upstream/skip-doc-change.yml.j2 diff --git a/.github/workflows/cffconvert.yml b/.github.upstream/workflows/cffconvert.yml similarity index 100% rename from .github/workflows/cffconvert.yml rename to .github.upstream/workflows/cffconvert.yml diff --git a/.github/workflows/codeql.yml b/.github.upstream/workflows/codeql.yml similarity index 100% rename from .github/workflows/codeql.yml rename to .github.upstream/workflows/codeql.yml diff --git a/.github/workflows/generated_fake_win_gpu_ci.yml b/.github.upstream/workflows/generated_fake_win_gpu_ci.yml similarity index 100% rename from .github/workflows/generated_fake_win_gpu_ci.yml rename to .github.upstream/workflows/generated_fake_win_gpu_ci.yml diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github.upstream/workflows/gradle-wrapper-validation.yml similarity index 100% rename from .github/workflows/gradle-wrapper-validation.yml rename to .github.upstream/workflows/gradle-wrapper-validation.yml diff --git a/.github/workflows/labeler.yml b/.github.upstream/workflows/labeler.yml similarity index 100% rename from .github/workflows/labeler.yml rename to .github.upstream/workflows/labeler.yml diff --git a/.github/workflows/lint.yml b/.github.upstream/workflows/lint.yml similarity index 100% rename from .github/workflows/lint.yml rename to .github.upstream/workflows/lint.yml diff --git a/.github/workflows/linux.yml b/.github.upstream/workflows/linux.yml similarity index 100% rename from .github/workflows/linux.yml rename to .github.upstream/workflows/linux.yml diff --git a/.github/workflows/publish-c-apidocs.yml b/.github.upstream/workflows/publish-c-apidocs.yml similarity index 100% rename from .github/workflows/publish-c-apidocs.yml rename to .github.upstream/workflows/publish-c-apidocs.yml diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github.upstream/workflows/publish-csharp-apidocs.yml similarity index 100% rename from .github/workflows/publish-csharp-apidocs.yml rename to .github.upstream/workflows/publish-csharp-apidocs.yml diff --git a/.github/workflows/publish-java-apidocs.yml b/.github.upstream/workflows/publish-java-apidocs.yml similarity index 100% rename from .github/workflows/publish-java-apidocs.yml rename to .github.upstream/workflows/publish-java-apidocs.yml diff --git a/.github/workflows/publish-python-apidocs.yml b/.github.upstream/workflows/publish-python-apidocs.yml similarity index 100% rename from .github/workflows/publish-python-apidocs.yml rename to .github.upstream/workflows/publish-python-apidocs.yml diff --git a/.github/workflows/windows.yml b/.github.upstream/workflows/windows.yml similarity index 100% rename from .github/workflows/windows.yml rename to .github.upstream/workflows/windows.yml diff --git a/.github/workflows/sca.yml b/.github/workflows/sca.yml deleted file mode 100644 index 1416f5a4d3..0000000000 --- a/.github/workflows/sca.yml +++ /dev/null @@ -1,133 +0,0 @@ -name: Windows_SCA -on: - push: - branches: - - main - - rel-* - pull_request: - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 - -jobs: - Onnxruntime-SCA-training-CUDA: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - uses: actions/setup-python@v4 - with: - python-version: '3.11.x' - architecture: 'x64' - - - uses: actions/setup-node@v3 - with: - node-version: 18 - - - name: Download cuda - run: azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v11.8" cuda_sdk - - - - name: Delete build folder - run: | - if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b } - &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x64 -install_prefix D:\b\Debug\installed -build_config Debug - - # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. - - name: Build code - env: - CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' - run: python tools\ci_build\build.py --windows_sdk_version 10.0.22621.0 --enable_training --build_java --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_pybind --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --use_cuda --cuda_home=${{ github.workspace }}\cuda_sdk\v11.8 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 - - - name: Generate sarif - working-directory: D:\b - run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output - - - name: Upload SARIF to GitHub - uses: github/codeql-action/upload-sarif@v2 - continue-on-error: true - with: - sarif_file: ${{ github.workspace }}\output\MergeResult.sarif - category: VS_SCA - - # No python - Onnxruntime-SCA-win32-WINML-x64: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - uses: actions/setup-python@v4 - with: - python-version: '3.11.x' - architecture: 'x64' - - - uses: actions/setup-node@v3 - with: - node-version: 18 - - - name: Delete build folder - run: | - if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b } - &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x64 -install_prefix D:\b\Debug\installed -build_config Debug - - # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. - - name: Build code - env: - CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' - run: python tools\ci_build\build.py --build_java --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --ms_experimental --use_dml --use_winml --disable_rtti --enable_wcos --build_shared_lib - - - name: Generate sarif - working-directory: D:\b - run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output - - - name: Upload SARIF to GitHub - uses: github/codeql-action/upload-sarif@v2 - continue-on-error: true - with: - sarif_file: ${{ github.workspace }}\output\MergeResult.sarif - category: VS_SCA_WIN32_WINML_X64 - - # No java, No python - Onnxruntime-SCA-win32-WINML-x86: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - uses: actions/setup-python@v4 - with: - python-version: '3.11.x' - architecture: 'x86' - - - uses: actions/setup-node@v3 - with: - node-version: 18 - - - name: Delete build folder - run: | - if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b } - &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x86 -install_prefix D:\b\Debug\installed -build_config Debug - - # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. - - name: Build code - env: - CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' - run: python tools\ci_build\build.py --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --ms_experimental --use_dml --use_winml --disable_rtti --enable_wcos --build_shared_lib - - - name: Generate sarif - working-directory: D:\b - run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output - - - name: Upload SARIF to GitHub - uses: github/codeql-action/upload-sarif@v2 - continue-on-error: true - with: - sarif_file: ${{ github.workspace }}\output\MergeResult.sarif - category: VS_SCA_WIN32_WINML_X86 diff --git a/.github/workflows/wheel.yaml b/.github/workflows/wheel.yaml new file mode 100644 index 0000000000..9537267ea8 --- /dev/null +++ b/.github/workflows/wheel.yaml @@ -0,0 +1,97 @@ +name: CI && Release & Upload Wheel + +on: + workflow_call: + inputs: + onnxruntime_branch: + type: string + default: "main" + workflow_dispatch: + inputs: + onnxruntime_branch: + type: string + default: "main" + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + build_and_upload_wheel_linux: + runs-on: The_CTOs_Choice + container: + image: ghcr.io/quadric-io/tvm:devel + options: "--mount type=bind,source=${{ github.workspace }},target=/workspace" + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + repository: quadric-io/onnxruntime + ref: ${{ inputs.onnxruntime_branch || github.ref }} + - name: Build ONNX Runtime wheel + working-directory: /workspace + run: | + python3 -m pip install cmake --upgrade + ./build.sh --build_wheel --config Release --parallel ${{ github.event_name == 'pull_request' && ' ' || '--skip_tests'}} --skip_submodule_sync --allow_running_as_root --compile_no_warning_as_error + wheel_path=$(find . -name '*.whl' | xargs readlink -f) + echo "wheel_path=$wheel_path" >> $GITHUB_ENV + - name: Upload Artifact + uses: actions/upload-artifact@v3 + with: + name: ort-wheel-linux + path: ${{ env.wheel_path }} + + build_and_upload_wheel_mac: + runs-on: [self-hosted, macOS, ARM64] + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + repository: quadric-io/onnxruntime + ref: ${{ inputs.onnxruntime_branch || github.ref }} + - name: Build ONNX Runtime wheel + run: | + ./build.sh --build_wheel --config Release --parallel ${{ github.event_name == 'pull_request' && ' ' || '--skip_tests'}} --skip_submodule_sync --compile_no_warning_as_error --skip_submodule_sync --apple_deploy_target 12 + wheel_path=$(find . -name '*.whl' | xargs readlink -f) + echo "wheel_path=$wheel_path" >> $GITHUB_ENV + - name: Upload Artifact + uses: actions/upload-artifact@v3 + with: + name: ort-wheel-mac + path: ${{ env.wheel_path }} + + create_release: + if: (github.ref == 'refs/heads/main') && ( github.event_name != 'workflow_call' && github.event_name != 'workflow_dispatch' ) + needs: [build_and_upload_wheel_mac, build_and_upload_wheel_linux] + runs-on: ubuntu-latest + steps: + - name: Download ort-wheel-linux artifact + uses: actions/download-artifact@v3 + with: + name: ort-wheel-linux + path: artifacts/ + - name: Download ort-wheel-mac artifact + uses: actions/download-artifact@v3 + with: + name: ort-wheel-mac + path: artifacts/ + - name: Count releases + id: count_releases + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + count=$(curl --request GET \ + --url https://api.github.com/repos/${{ github.repository }}/releases \ + --header "Authorization: Bearer $GITHUB_TOKEN" | jq length) + echo "count=$count" >> $GITHUB_ENV + - name: Create Release and Upload Both Assets + uses: softprops/action-gh-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: v${{ env.count }} + name: Release v${{ env.count }} + files: | + artifacts/*.whl diff --git a/.gitmodules b/.gitmodules index 036a248070..7bb49e98bf 100644 --- a/.gitmodules +++ b/.gitmodules @@ -8,6 +8,3 @@ path = cmake/external/emsdk url = https://github.com/emscripten-core/emsdk.git branch = 3.1.44 -[submodule "cmake/external/onnxruntime-extensions"] - path = cmake/external/onnxruntime-extensions - url = https://github.com/microsoft/onnxruntime-extensions.git diff --git a/README_EPU.md b/README_EPU.md new file mode 100644 index 0000000000..5ad295dd36 --- /dev/null +++ b/README_EPU.md @@ -0,0 +1,29 @@ +# The Quadric Version of onnxruntime + +This repository contains the a distribution of onnxruntime with additional operator quantization capabilities. + + +## Prerequisites: +- python 3.9 +- pip + +## Clone repository and build: +``` +git clone --recursive https://github.com/quadric-io/onnxruntime onnxruntime +cd onnxruntime +python3.9 -m venv venv +source venv/bin/activate +# Install required packages. numpy version is restricted by TVM +pip3 install wheel packaging numpy==1.24.4 +# Build the python package +./build.sh --build_wheel --config Release --parallel +``` + +## Install +``` +# Find the wheel you just created +$ find . -name '*.whl' +./build/MacOS/Release/dist/onnxruntime-1.16.0-cp39-cp39-macosx_13_0_arm64.whl +# Install it +pip3 install ./build/MacOS/Release/dist/onnxruntime-1.16.0-cp39-cp39-macosx_13_0_arm64.whl +``` diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 6f6faa3a2e..985eb64566 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -6230,3 +6230,37 @@ https://github.com/intel/neural-compressor terms, and open source software license terms. These separate license terms govern your use of the third party programs as set forth in the "THIRD-PARTY-PROGRAMS" file. + +_____ + +FlashAttention, https://github.com/Dao-AILab/flash-attention + +BSD 3-Clause License + +Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 15b989e398..4a02d2c317 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.16.0 +1.16.2 diff --git a/cgmanifests/cgmanifest.json b/cgmanifests/cgmanifest.json index 2a3de3bb0e..e8dbc9cf9e 100644 --- a/cgmanifests/cgmanifest.json +++ b/cgmanifests/cgmanifest.json @@ -568,7 +568,7 @@ "component": { "type": "git", "git": { - "commitHash": "d10b27fe37736d2944630ecd7557cefa95cf87c9", + "commitHash": "e7248b26a1ed53fa030c5c459f7ea095dfd276ac", "repositoryUrl": "https://gitlab.com/libeigen/eigen.git" } } diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index b01ed00350..82a454791d 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -84,7 +84,8 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) -option(onnxruntime_USE_FLASH_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) +cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF) +option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF) option(onnxruntime_USE_AVX "Use AVX instructions" OFF) @@ -666,13 +667,16 @@ if (onnxruntime_USE_CUDA) if (onnxruntime_DISABLE_CONTRIB_OPS) set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() else() set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() if (onnxruntime_USE_CUDA) @@ -685,6 +689,11 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1) endif() + if (onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) + message( STATUS "Enable memory efficient attention for CUDA EP") + list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1) + endif() endif() if (onnxruntime_USE_VITISAI) diff --git a/cmake/deps.txt b/cmake/deps.txt index 1e685511b7..2965c60277 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -11,6 +11,7 @@ abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20220623.1.zip cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v2.4.1.zip;ea99f021262b1d804a872735c658860a6a13cc98 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 +eigen;https://gitlab.com/libeigen/eigen/-/archive/e7248b26a1ed53fa030c5c459f7ea095dfd276ac/eigen-e7248b26a1ed53fa030c5c459f7ea095dfd276ac.zip;be8be39fdbc6e60e94fa7870b280707069b5b81a flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;ba0a75fd12dbef8f6557a74e611b7a3d0c5fe7bf fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 @@ -41,5 +42,4 @@ re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91 safeint;https://github.com/dcleblanc/SafeInt/archive/ff15c6ada150a5018c5ef2172401cb4529eac9c0.zip;913a4046e5274d329af2806cb53194f617d8c0ab tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.0.0.zip;0f95b3c1fc1bd1175c4a90b2c9e39074d1bccefd -extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -eigen;https://gitlab.com/libeigen/eigen/-/archive/3.4/eigen-3.4.zip;ee201b07085203ea7bd8eb97cbcb31b07cfa3efb +extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c \ No newline at end of file diff --git a/cmake/deps_update_and_upload.py b/cmake/deps_update_and_upload.py new file mode 100644 index 0000000000..d357284d91 --- /dev/null +++ b/cmake/deps_update_and_upload.py @@ -0,0 +1,56 @@ +# in case deps.txt is updated, run this file to update and upload the dependencies so that CI can use them. +# Before running the script, increase the version number found at: +# https://aiinfra.visualstudio.com/Lotus/_artifacts/feed/Lotus/UPack/onnxruntime_build_dependencies/versions +# Run without --do-upload once to verify downloading. Use --do-upload when you are ready to publish. +# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --do-upload +# update version number in tools\ci_build\github\azure-pipelines\templates\download-deps.yml +import re +import subprocess +import os +import argparse +import tempfile + +parser = argparse.ArgumentParser(description="Update dependencies and publish to Azure Artifacts") +parser.add_argument( + "--root-path", type=str, default=tempfile.gettempdir(), help="Target root path for downloaded files" +) +parser.add_argument("--version", type=str, default="1.0.82", help="Package version to publish") +parser.add_argument("--do-upload", action="store_true", help="Upload the package to Azure Artifacts") +args = parser.parse_args() + +with open("cmake/deps.txt") as file: + text = file.read() + +lines = [line for line in text.split("\n") if not line.startswith("#") and ";" in line] + +root_path = args.root_path + +for line in lines: + url = re.sub("^[^;]+?;https://([^;]+?);.*", r"https://\1", line) + filename = re.sub("^[^;]+?;https://([^;]+?);.*", r"\1", line) + full_path = os.path.join(root_path, filename) + subprocess.run(["curl", "-sSL", "--create-dirs", "-o", full_path, url]) # noqa: PLW1510 + +package_name = "onnxruntime_build_dependencies" +version = args.version + +# Check if the user is logged in to Azure +result = subprocess.run("az account show", shell=True, capture_output=True, text=True) # noqa: PLW1510 +if "No subscriptions found" in result.stderr: + # Prompt the user to log in to Azure + print("You are not logged in to Azure. Please log in to continue.") + subprocess.run("az login", shell=True) # noqa: PLW1510 + +# Publish the package to Azure Artifacts if --no-upload is not specified + +cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' +if args.do_upload: + subprocess.run(cmd, shell=True) # noqa: PLW1510 +else: + print("would have run: " + cmd) + +cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' +if args.do_upload: + subprocess.run(cmd, shell=True) # noqa: PLW1510 +else: + print("would have run: " + cmd) diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index 18ac668bb1..8c5d81d638 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,4 +1,4 @@ -if (onnxruntime_USE_FLASH_ATTENTION) +if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) include(FetchContent) FetchContent_Declare( cutlass diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index c087ad8f6d..8e412c7847 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -46,8 +46,8 @@ if (onnxruntime_BUILD_UNIT_TESTS) FetchContent_Declare( googletest URL ${DEP_URL_googletest} + FIND_PACKAGE_ARGS 1.13.0...<2.0.0 NAMES GTest URL_HASH SHA1=${DEP_SHA1_googletest} - OVERRIDE_FIND_PACKAGE ) endif() @@ -528,4 +528,3 @@ endif() FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} ORT_BINARY_DIR) FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} ORT_SOURCE_DIR) - diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 5adfc7ba03..03360ff30c 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -529,7 +529,7 @@ if (onnxruntime_USE_CUDA) target_link_libraries(${target} PRIVATE cuda) endif() - if (onnxruntime_USE_FLASH_ATTENTION) + if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) include(cutlass) target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 982912fb12..bf9adbaefa 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -466,6 +466,9 @@ file(GLOB onnxruntime_python_transformers_models_bert_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_transformers_models_gpt2_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/gpt2/*.py" ) +file(GLOB onnxruntime_python_transformers_models_llama_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/llama/*.py" +) file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py" ) @@ -537,6 +540,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/bart COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/bert COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/gpt2 + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/llama COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/longformer COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/stable_diffusion COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/t5 @@ -628,6 +632,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_gpt2_src} $/onnxruntime/transformers/models/gpt2/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_llama_src} + $/onnxruntime/transformers/models/llama/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_longformer_src} $/onnxruntime/transformers/models/longformer/ diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index c8592a4019..64c4ca7d15 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -10,6 +10,9 @@ set(contrib_ops_excluded_files "bert/attention_impl.cu" "bert/attention_softmax.h" "bert/attention_softmax.cu" + "bert/attention_prepare_qkv.cu" + "bert/decoder_attention_impl.h" + "bert/decoder_attention_impl.cu" "bert/decoder_masked_multihead_attention.h" "bert/decoder_masked_multihead_attention.cc" "bert/decoder_masked_self_attention.h" @@ -58,6 +61,16 @@ set(contrib_ops_excluded_files "quantization/attention_quantization.h" "quantization/attention_quantization_impl.cu" "quantization/attention_quantization_impl.cuh" + "quantization/dequantize_blockwise.cuh" + "quantization/dequantize_blockwise.cu" + "quantization/dequantize_blockwise_bnb4.cuh" + "quantization/dequantize_blockwise_bnb4.cu" + "quantization/matmul_bnb4.cc" + "quantization/matmul_bnb4.cuh" + "quantization/matmul_bnb4.cu" + "quantization/matmul_nbits.cc" + "quantization/matmul_nbits.cuh" + "quantization/matmul_nbits.cu" "quantization/quantize_dequantize_linear.cc" "quantization/qordered_ops/qordered_attention_impl.cu" "quantization/qordered_ops/qordered_attention_impl.h" @@ -100,6 +113,11 @@ set(contrib_ops_excluded_files "cuda_contrib_kernels.h" "inverse.cc" "fused_conv.cc" + "bert/group_query_attention_helper.h" + "bert/group_query_attention.h" + "bert/group_query_attention.cc" + "bert/group_query_attention_impl.h" + "bert/group_query_attention_impl.cu" ) if (NOT onnxruntime_ENABLE_ATEN) @@ -201,6 +219,10 @@ set(training_ops_excluded_files "reduction/reduction_ops.cc" # no double type support "cuda_training_kernels.cc" "cuda_training_kernels.h" + "nn/conv_shared.cc" + "nn/conv_shared.h" + "nn/conv_transpose_grad.cc" + "nn/conv_transpose_grad.h" ) function(auto_set_source_files_hip_language) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 2ba837be22..f722ca9d30 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -1860,54 +1860,61 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtFillStringTensor OrtFillStringTensor; + /// \param value A tensor created from OrtCreateTensor... function. + /// \param index The index of the entry in the tensor to resize. + /// \param length_in_bytes Length to resize the string to. + /// \param buffer The resized buffer. + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetResizedStringTensorElementBuffer( - IntPtr /* OrtValue */ value, - UIntPtr /* size_t */ index, - UIntPtr /* size_t */ length_in_bytes, - out IntPtr /* char** */ buffer - ); + IntPtr /* OrtValue */ value, + UIntPtr /* size_t */ index, + UIntPtr /* size_t */ length_in_bytes, + out IntPtr /* char** */ buffer); public static DOrtGetResizedStringTensorElementBuffer OrtGetResizedStringTensorElementBuffer; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorContent( - IntPtr /*(OrtValue*)*/ value, - byte[] /*(void*)*/ dst_buffer, - UIntPtr dst_buffer_len, - UIntPtr[] offsets, - UIntPtr offsets_len); + IntPtr /*(OrtValue*)*/ value, + byte[] /*(void*)*/ dst_buffer, + UIntPtr dst_buffer_len, + UIntPtr[] offsets, + UIntPtr offsets_len); public static DOrtGetStringTensorContent OrtGetStringTensorContent; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorDataLength(IntPtr /*(OrtValue*)*/ value, - out UIntPtr /*(size_t*)*/ len); + out UIntPtr /*(size_t*)*/ len); public static DOrtGetStringTensorDataLength OrtGetStringTensorDataLength; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElementLength(IntPtr /*(OrtValue*)*/ value, - UIntPtr /*(size_t)*/ index, - out UIntPtr /*(size_t*)*/ len); + UIntPtr /*(size_t)*/ index, + out UIntPtr /*(size_t*)*/ len); public static DOrtGetStringTensorElementLength OrtGetStringTensorElementLength; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorElement(IntPtr /*(OrtValue*)*/ value, - UIntPtr /*(size_t)*/ bufferLength, - UIntPtr /*(size_t)*/ elementIndex, - byte[] buffer); + UIntPtr /*(size_t)*/ bufferLength, + UIntPtr /*(size_t)*/ elementIndex, + byte[] buffer); public static DOrtGetStringTensorElement OrtGetStringTensorElement; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ - DOrtCastTypeInfoToTensorInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToTensorInfo( + IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, + out IntPtr /*(const struct OrtTensorTypeAndShapeInfo**)*/ typeAndShapeInfo); public static DOrtCastTypeInfoToTensorInfo OrtCastTypeInfoToTensorInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape( + IntPtr /*(OrtValue*)*/ value, + out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); public static DOrtGetTensorTypeAndShape OrtGetTensorTypeAndShape; @@ -1917,12 +1924,16 @@ out IntPtr /* char** */ buffer public static DOrtReleaseTensorTypeAndShapeInfo OrtReleaseTensorTypeAndShapeInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out IntPtr /*(TensorElementType*)*/ output); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorElementType( + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + out IntPtr /*(TensorElementType*)*/ output); public static DOrtGetTensorElementType OrtGetTensorElementType; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out UIntPtr output); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount( + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + out UIntPtr output); public static DOrtGetDimensionsCount OrtGetDimensionsCount; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index b374371446..86b44a6784 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -743,7 +743,7 @@ internal static OrtValue CreateFromTensorObject(TensorBase value, out TensorElem /// /// Creates an OrtValue that contains a string tensor of specified shape, and /// containing empty strings. String tensors are always on CPU. - /// Use FillStringTensorElement to assign individual elements values. + /// Use StringTensorSetElementAt to assign individual elements values. /// /// /// disposable OrtValue diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index 659c630370..6889112acb 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -40,20 +40,16 @@ internal enum PropertyType : long String = 2 } - private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) + private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); - T[] value = new T[1]; - value[0] = propertyValue; - Memory memory = value; - using (var memHandle = memory.Pin()) + T[] value = { propertyValue }; + unsafe { - IntPtr memPtr; - unsafe + fixed (T* memPtr = value) { - memPtr = (IntPtr)memHandle.Pointer; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr)); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr)); } } @@ -103,13 +99,13 @@ public static void SaveCheckpoint(CheckpointState state, string checkpointPath, } /// - /// Adds the given int property to the checkpoint state. + /// Adds or updates the given int property to/in the checkpoint state. /// - /// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, long propertyValue) { @@ -117,13 +113,13 @@ public void AddProperty(string propertyName, long propertyValue) } /// - /// Adds the given float property to the checkpoint state. + /// Adds or updates the given float property to/in the checkpoint state. /// - /// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, float propertyValue) { @@ -131,28 +127,25 @@ public void AddProperty(string propertyName, float propertyValue) } /// - /// Adds the given string property to the checkpoint state. + /// Adds or updates the given string property to/in the checkpoint state. /// - /// Runtime properties that are strings such as parameter names, custom strings, and others can be added - /// to the checkpoint state by the user if they desire by calling this function with the appropriate property - /// name and value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, string propertyValue) { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue); - IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length); - try - { - Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer)); - } - finally + unsafe { - Marshal.FreeHGlobal(unmanagedPointer); + fixed (byte* p = propertyValueUtf8) + { + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p)); + } } } @@ -162,34 +155,86 @@ public void AddProperty(string propertyName, string propertyValue) /// Gets the property value from an existing entry in the checkpoint state. The property must /// exist in the checkpoint state to be able to retrieve it successfully. /// - /// Unique name of the property being retrieved. + /// Name of the property being retrieved. /// Property value associated with the given property name. public object GetProperty(string propertyName) { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var allocator = OrtAllocator.DefaultInstance; IntPtr propertyValue = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue)); - if (propertyType == PropertyType.Int) + try { - var longPropertyValue = Marshal.ReadInt64(propertyValue); - allocator.FreeMemory(propertyValue); - return longPropertyValue; + if (propertyType == PropertyType.Int) + { + Int64 value; + unsafe + { + value = *(Int64*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.Float) + { + float value; + unsafe + { + value = *(float*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.String) + { + return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue); + } + + throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); } - else if (propertyType == PropertyType.Float) + finally { - float[] value = new float[1]; - Marshal.Copy(propertyValue, value, 0, 1); allocator.FreeMemory(propertyValue); - return value[0]; } - else if (propertyType == PropertyType.String) + } + + /// + /// Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + /// + /// This function updates a model parameter in the checkpoint state with the given parameter data. + /// The training session must be already created with the checkpoint state that contains the parameter + /// being updated. The given parameter is copied over to the registered device for the training session. + /// The parameter must exist in the checkpoint state to be able to update it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that should replace the existing parameter data. + public void UpdateParameter(string parameterName, OrtValue parameter) + { + if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { - return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator); + throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter."); } - throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); + var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle)); + } + + /// + /// Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + /// + /// This function retrieves the model parameter data from the checkpoint state for the given parameter name. + /// The parameter is copied over to the provided OrtValue. The training session must be already created + /// with the checkpoint state that contains the parameter being retrieved. + /// The parameter must exist in the checkpoint state to be able to retrieve it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that is retrieved from the checkpoint state. + public OrtValue GetParameter(string parameterName) + { + var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle)); + + return new OrtValue(parameterHandle); } #region SafeHandle diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index c52ca4d1a4..d6341b90f2 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -15,6 +15,7 @@ public struct OrtTrainingApi public IntPtr LoadCheckpoint; public IntPtr SaveCheckpoint; public IntPtr CreateTrainingSession; + public IntPtr CreateTrainingSessionFromBuffer; public IntPtr TrainingSessionGetTrainingModelOutputCount; public IntPtr TrainingSessionGetEvalModelOutputCount; public IntPtr TrainingSessionGetTrainingModelOutputName; @@ -41,6 +42,9 @@ public struct OrtTrainingApi public IntPtr AddProperty; public IntPtr GetProperty; public IntPtr LoadCheckpointFromBuffer; + public IntPtr GetParameterTypeAndShape; + public IntPtr UpdateParameter; + public IntPtr GetParameter; } internal static class NativeTrainingMethods @@ -96,6 +100,9 @@ static NativeTrainingMethods() OrtGetEvalModelInputName = (DOrtGetEvalModelInputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelInputName, typeof(DOrtGetEvalModelInputName)); OrtAddProperty = (DOrtAddProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.AddProperty, typeof(DOrtAddProperty)); OrtGetProperty = (DOrtGetProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetProperty, typeof(DOrtGetProperty)); + OrtGetParameterTypeAndShape = (DOrtGetParameterTypeAndShape)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameterTypeAndShape, typeof(DOrtGetParameterTypeAndShape)); + OrtUpdateParameter = (DOrtUpdateParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.UpdateParameter, typeof(DOrtUpdateParameter)); + OrtGetParameter = (DOrtGetParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameter, typeof(DOrtGetParameter)); } } @@ -358,6 +365,34 @@ out UIntPtr inputCount public static DOrtGetProperty OrtGetProperty; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameterTypeAndShape( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + out IntPtr /*(OrtTensorTypeAndShapeInfo**)*/ parameterTypeAndShape + ); + + public static DOrtGetParameterTypeAndShape OrtGetParameterTypeAndShape; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtUpdateParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtValue*)*/ parameter + ); + + public static DOrtUpdateParameter OrtUpdateParameter; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(OrtValue**)*/ parameter + ); + + public static DOrtGetParameter OrtGetParameter; + #endregion TrainingSession API public static bool TrainingEnabled() diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 33993c2be1..877677dcad 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -358,13 +358,14 @@ public void EvalStep( IReadOnlyCollection inputValues, IReadOnlyCollection outputValues) { - if (!_evalOutputCount.Equals(outputValues.Count)) + if (_evalOutputCount != (ulong)outputValues.Count()) { - throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount})."); + throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount})."); } - IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true); + const bool isInput = true; + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, isInput); - IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */ + IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, !isInput); /* pointers to Pre-allocated OrtValue instances */ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count, inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); } @@ -509,18 +510,17 @@ public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollec /// Returns a contiguous buffer that holds a copy of all training state parameters /// /// Whether to only copy trainable parameters or to copy all parameters. - public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) + public OrtValue ToBuffer(bool onlyTrainable) { UIntPtr bufferSize = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, onlyTrainable)); float[] bufferMemory = new float[bufferSize.ToUInt64()]; - var memInfo = OrtMemoryInfo.DefaultInstance; // CPU - var shape = new long[] { (long)bufferSize.ToUInt64() }; - var buffer = FixedBufferOnnxValue.CreateFromMemory(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float)); + var shape = new long[] { (long)bufferSize }; + var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, onlyTrainable)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable)); return buffer; } @@ -528,45 +528,30 @@ public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) /// /// Loads the training session model parameters from a contiguous buffer /// - /// Contiguous buffer to load the parameters from. - public void FromBuffer(FixedBufferOnnxValue buffer) + /// Contiguous buffer to load the parameters from. + /// Whether to only load trainable parameters or to load all parameters. + public void FromBuffer(OrtValue ortValue, bool onlyTrainable) { - if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) + if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer."); } - IntPtr typeAndShapeInfo = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo)); - UIntPtr numDimensions = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions)); - if (numDimensions.ToUInt64() != 1) + var tensorInfo = ortValue.GetTensorTypeAndShape(); + if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float) { - string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString(); - throw new ArgumentException(errorMessage); - } - - // Here buffer size represents the number of elements in the buffer - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize)); - - // OrtGetParametersSize returns the total number of elements in the model's parameters. - UIntPtr numElementsTrainingOnly = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true)); - if ((ulong)bufferSize == (ulong)numElementsTrainingOnly) - { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true)); - return; + throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer of type float."); } UIntPtr numElements = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false)); - if ((ulong)bufferSize != (ulong)numElements) + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable)); + if ((ulong)tensorInfo.ElementCount != (ulong)numElements) { - string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString(); + string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". Actual size: " + tensorInfo.ElementCount.ToString(); throw new ArgumentException(errorMessage); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable)); } /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index ea2b6d7dbc..68b1d5bcc6 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -484,20 +484,23 @@ public void TestEvalModelOutputNames() public void TestToBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); + } } } @@ -505,22 +508,25 @@ public void TestToBuffer() public void TestFromBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); - trainingSession.FromBuffer(buffer); + trainingSession.FromBuffer(buffer, true); + } } } @@ -530,6 +536,82 @@ public void TestSetSeed() TrainingUtils.SetSeed(8888); } + [Fact(DisplayName = "TestGetParameter")] + public void TestGetParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(state); + Assert.NotNull(parameter); + + var typeShape = parameter.GetTensorTypeAndShape(); + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + } + } + + [Fact(DisplayName = "TestUpdateParameter")] + public void TestUpdateParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + { + Assert.NotNull(state); + + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(parameter); + var typeShape = parameter.GetTensorTypeAndShape(); + + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + + float maxVal = 20; + Random randNum = new Random(); + float[] updated_parameter_buffer = Enumerable + .Repeat(0, 500 * 784) + .Select(i => maxVal * (float)randNum.NextDouble()) + .ToArray(); + + using (var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape)) + { + state.UpdateParameter("fc1.weight", updated_parameter); + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(updated_parameter_buffer, current_parameter_tensor); + Assert.NotEqual(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + } + + state.UpdateParameter("fc1.weight", parameter); + + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor); + } + } + } + } + } + internal class FloatComparer : IEqualityComparer { private float atol = 1e-3f; diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 63a7289dd9..7778a4d369 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -41,12 +41,15 @@ Do not modify directly.* * com.microsoft.GreedySearch * com.microsoft.GridSample * com.microsoft.GroupNorm + * com.microsoft.GroupQueryAttention * com.microsoft.Inverse * com.microsoft.Irfft * com.microsoft.LongformerAttention + * com.microsoft.MatMulBnb4 * com.microsoft.MatMulFpQ4 * com.microsoft.MatMulInteger16 * com.microsoft.MatMulIntegerToFloat + * com.microsoft.MatMulNBits * com.microsoft.MaxpoolWithMask * com.microsoft.MulInteger * com.microsoft.MultiHeadAttention @@ -86,8 +89,10 @@ Do not modify directly.* * com.microsoft.RemovePadding * com.microsoft.RestorePadding * com.microsoft.Rfft + * com.microsoft.RotaryEmbedding * com.microsoft.SampleOp * com.microsoft.Sampling + * com.microsoft.SkipGroupNorm * com.microsoft.SkipLayerNormalization * com.microsoft.SkipSimplifiedLayerNormalization * com.microsoft.Snpe @@ -1169,9 +1174,9 @@ This version of the operator has been available since version 1 of the 'com.micr
output : T
3D output tensor with shape (batch_size, sequence_length, v_hidden_size)
present_key (optional) : T
-
past state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
+
present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
present_value (optional) : T
-
past state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
+
present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
#### Type Constraints @@ -2181,7 +2186,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation : int (required)
-
Activation after group normalization: 0 for None, 1 for Swish
+
Activation after group normalization: 0 for None, 1 for SiLU
channels_last : int
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
epsilon : float
@@ -2218,6 +2223,67 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.GroupQueryAttention** + + Group Query Self/Cross Attention. + + Supports different number of heads for q and kv. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
kv_num_heads : int (required)
+
Number of attention heads for k and v
+
num_heads : int (required)
+
Number of attention heads for q
+
scale : float
+
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+
+ +#### Inputs + +
+
query : T
+
Query with shape (batch_size, sequence_length, hidden_size)
+
key : T
+
Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
+
value : T
+
Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
+
past_key (optional) : T
+
past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
+
past_value (optional) : T
+
past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
+
seqlens_k : M
+
1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
+
total_sequence_length : M
+
Scalar tensor of total sequence length (past + new).
+
+ +#### Outputs + +
+
output : T
+
3D output tensor with shape (batch_size, sequence_length, hidden_size)
+
present_key : T
+
present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
+
present_value : T
+
present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
+
+ +#### Type Constraints + +
+
T : tensor(float16)
+
Constrain input and output to float tensors.
+
M : tensor(int32)
+
Constrain mask to int tensor.
+
+ + ### **com.microsoft.Inverse** #### Version @@ -2347,6 +2413,63 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MatMulBnb4** + + MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's quantization constants or scales are specified by input 'absmax'. + + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
size of each input feature
+
N : int (required)
+
size of each output feature
+
block_size : int (required)
+
number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.
+
quant_type : int (required)
+
quantization data type. 0 for FP4, 1 for NF4.
+
+ +#### Inputs + +
+
A : T1
+
The input tensor, not quantized
+
B : T2
+
1-dimensional quantized data for weight
+
absmax : T1
+
quantization constants
+
+ +#### Outputs + +
+
Y : T1
+
tensor. The output tensor has the same rank as the input.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16)
+
Constrain input and output types to float/half_float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
+ + ### **com.microsoft.MatMulFpQ4** Matrix product with right hand matrix being pre-packed and quantized int4 data blob. @@ -2479,6 +2602,78 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MatMulNBits** + + MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's scale and zero point are specified by input scales and zero_points. + + Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = block_size / 8 * bits + + For a block blob. It is stored in format: + struct Blob { + uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization + uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization + uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization + } + + Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] + Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: + - [(N * n_blocks_per_col + 1) / 2] if bits <=4 + - [N * n_blocks_per_col] if bits > 4 + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
size of each input feature
+
N : int (required)
+
size of each output feature
+
bits : int (required)
+
number of bits used for weight quantization (default 4)
+
block_size : int (required)
+
number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
+
+ +#### Inputs (3 - 4) + +
+
A : T1
+
The input tensor, not quantized
+
B : T2
+
1-dimensional data blob
+
scales : T1
+
quantization scale
+
zero_points (optional) : T2
+
quantization zero points
+
+ +#### Outputs + +
+
Y : T1
+
tensor. The output tensor has the same rank as the input.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16)
+
Constrain input and output types to float/half_float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
+ + ### **com.microsoft.MaxpoolWithMask** For internal use. @@ -2606,7 +2801,7 @@ This version of the operator has been available since version 1 of the 'com.micr
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
-
Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)
+
Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)
relative_position_bias (optional) : T
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
past_key (optional) : T
@@ -4567,6 +4762,54 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.RotaryEmbedding** + + RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices + that are multiplied to query and key before the inner product of query and key is taken. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
interleaved : int
+
Rotate using interleaved pattern. Default value is 0 (False).
+
scale : float
+
Custom scale will be used if specified. Default value is 1.0
+
+ +#### Inputs + +
+
input : T
+
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
position_ids : M
+
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
+
cos_cache : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
sin_cache : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
+ +#### Outputs + +
+
output : T
+
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float tensors.
+
M : tensor(int64)
+
Constrain input and output types to integer tensors
+
+ + ### **com.microsoft.SampleOp** Sample echo operator. @@ -4682,6 +4925,72 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.SkipGroupNorm** + + This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + + This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + + The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. + The num_channels must be divisible by num_groups. + The mean and standard-deviation of s are calculated separately over the each group. + The weight and bias are per-channel affine transform parameter vectors of size num_channels. + + The activation attribute can be used to enable activation after group normalization. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : int (required)
+
Activation after group normalization: 0 for None, 1 for SiLU
+
channels_last : int
+
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
+
epsilon : float
+
The epsilon value to use to avoid division by zero
+
groups : int (required)
+
The number of groups of channels. It should be a divisor of the number of channels C
+
+ +#### Inputs (4 - 5) + +
+
X : T
+
Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data
+
gamma : M
+
1D gamma tensor for normalization with shape (C), where C is number of channels
+
beta : M
+
1D beta tensor for normalization with shape (C), where C is number of channels
+
skip : T
+
4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)
+
bias (optional) : T
+
1D bias tensor. Dimensions are (C), where C is number of channels
+
+ +#### Outputs (1 - 2) + +
+
Y : T
+
The output tensor of the same shape as X
+
S (optional) : T
+
The element-wise sum of input x, skip and bias tensors. It has the same shape as X
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float)
+
Constrain input X, skip, bias and output Y, S types to float tensors.
+
M : tensor(float16), tensor(float)
+
Constrain gamma and beta to float tensors.
+
+ + ### **com.microsoft.SkipLayerNormalization** Skip and Layer Normalization Fusion diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index c76f760ef0..f4142adc07 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -453,9 +453,11 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| @@ -475,9 +477,11 @@ Do not modify directly.* |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float)| |SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| +|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| |SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| @@ -759,6 +763,7 @@ Do not modify directly.* |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| +|Sign|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float), tensor(float16)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| |Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| @@ -838,9 +843,12 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -859,7 +867,9 @@ Do not modify directly.* |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| +|SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| @@ -1243,6 +1253,7 @@ Do not modify directly.* |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | diff --git a/docs/python/README.rst b/docs/python/README.rst index 7d978b0941..bcf7c635af 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,16 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime ((b & 0x80000000) >> 24); // sign - if ((b & 0x7fc00000) == 0x7fc00000) { - val |= 0x7f; - } else if ((b & 0x7fffffff) == 0x7f800000) { + if ((b & 0x7fffffff) == 0x7f800000) { // infinity if (saturate) { val |= 126; } else { val |= 0x7f; } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val |= 0x7f; } else { uint8_t e = static_cast((b & 0x7F800000) >> 23); // exponent uint32_t m = static_cast(b & 0x007FFFFF); // mantissa if (e != 0) { - if (e < 117) { // 0b1110101 - } else if (e < 118) { // 0b1110110 - val |= 1; - if ((m >> 23) & 1) { - // rounding - val += 1; + if (e < 117) { + } else if (e < 121) { + // denormalized number + auto d = 120 - e; + if (d < 3) { + val |= 1 << (2 - d); + val |= m >> (21 + d); + } else if (m > 0) { + val |= 1; } - } else if (e < 121) { // 127 - 7 + 1 // 0b1111001 - auto d = 120 - e; // 0b1111000 - val |= 1 << (2 - d); - val |= m >> (21 + d); - if ((m >> (20 + d)) & 1) { + auto mask = 1 << (20 + d); + if ((m & mask) && + ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { // rounding val += 1; } - } else if (e < 136) { // 127 + 8 + 1 // 0b10001000 - auto ex = e - 120; // 127 - 7 + } else if (e < 136) { + // normalized number + auto ex = e - 120; if (ex == 0) { val |= 0x4; val |= m >> 21; @@ -83,7 +85,7 @@ struct Float8E4M3FN { val &= 0xFE; } } - if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7C000))) { + if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) { if ((val & 0x7F) < 0x7E) { // rounding val += 1; @@ -147,14 +149,22 @@ struct Float8E4M3FN { inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 - explicit ORT_HOST_DEVICE Float8E4M3FN(const __nv_fp8_e4m3& value) { val = *reinterpret_cast(&value); } + explicit ORT_HOST_DEVICE Float8E4M3FN(const __nv_fp8_e4m3& value) { + val = *reinterpret_cast(&value); + } explicit ORT_HOST_DEVICE operator __nv_fp8_e4m3() const { return *reinterpret_cast(&val); } #endif }; -inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val == right.val; } -inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val != right.val; } -inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val < right.val; } +inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FN& left, const Float8E4M3FN& right) { + return left.val == right.val; +} +inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN& left, const Float8E4M3FN& right) { + return left.val != right.val; +} +inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FN& left, const Float8E4M3FN& right) { + return left.val < right.val; +} // User defined suffixes to make it easier to declare // initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char @@ -164,9 +174,7 @@ inline Float8E4M3FN operator"" _f8e4m3fn(unsigned long long int v) { return Float8E4M3FN(narrow(v), Float8E4M3FN::FromBits()); } -inline Float8E4M3FN operator"" _f8e4m3fnp8(long double v) { - return Float8E4M3FN(static_cast(v), true); -} +inline Float8E4M3FN operator"" _f8e4m3fnp8(long double v) { return Float8E4M3FN(static_cast(v), true); } #endif @@ -205,36 +213,38 @@ struct Float8E4M3FNUZ { std::memcpy(&b, &v, sizeof(b)); val = static_cast((b & 0x80000000) >> 24); // sign - if ((b & 0x7fc00000) == 0x7fc00000) { - val = 0x80; - } else if ((b & 0x7fffffff) == 0x7f800000) { + if ((b & 0x7fffffff) == 0x7f800000) { // infinity if (saturate) { val |= 0x7F; } else { // infinity val = 0x80; } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val = 0x80; } else { uint8_t e = static_cast((b & 0x7F800000) >> 23); // exponent uint32_t m = static_cast(b & 0x007FFFFF); // mantissa if (e != 0) { if (e < 116) { - } else if (e < 117) { - val |= 1; - if ((m >> 23) & 1) { - // rounding - val += 1; - } - } else if (e < 120) { // 127 - 8 + 1 + } else if (e < 120) { + // denormalized number auto d = 119 - e; - val |= 1 << (2 - d); - val |= m >> (21 + d); - if ((m >> (20 + d)) & 1) { + if (d < 3) { + val |= 1 << (2 - d); + val |= m >> (21 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (20 + d); + if ((m & mask) && + ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { // rounding val += 1; } - } else if (e < 135) { // 127 + 8 - auto ex = e - 119; // 127 - 7 + } else if (e < 135) { + // normalized number + auto ex = e - 119; if (ex == 0) { val |= 0x4; val |= m >> 21; @@ -242,7 +252,7 @@ struct Float8E4M3FNUZ { val |= ex << 3; val |= m >> 20; } - if (m & 0x80000) { + if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) { if ((val & 0x7F) < 0x7F) { // rounding val += 1; @@ -303,9 +313,15 @@ struct Float8E4M3FNUZ { inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } }; -inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val == right.val; } -inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val != right.val; } -inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val < right.val; } +inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { + return left.val == right.val; +} +inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { + return left.val != right.val; +} +inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { + return left.val < right.val; +} // User defined suffixes to make it easier to declare // initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char @@ -315,9 +331,7 @@ inline Float8E4M3FNUZ operator"" _f8e4m3p8fnuz(unsigned long long int v) { return Float8E4M3FNUZ(narrow(v), Float8E4M3FNUZ::FromBits()); } -inline Float8E4M3FNUZ operator"" _f8e4m3fnuzp8(long double v) { - return Float8E4M3FNUZ(static_cast(v), true); -} +inline Float8E4M3FNUZ operator"" _f8e4m3fnuzp8(long double v) { return Float8E4M3FNUZ(static_cast(v), true); } #endif @@ -357,32 +371,33 @@ struct Float8E5M2 { uint32_t b; std::memcpy(&b, &v, sizeof(b)); - val = (b & 0x80000000) >> 24; // sign - if ((b & 0x7fc00000) == 0x7fc00000) { - val |= 0x7f; - } else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf + val = (b & 0x80000000) >> 24; // sign + if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf if (saturate) { val |= 0x7B; } else { val |= 0x7C; } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val |= 0x7f; } else { uint32_t e = (b & 0x7F800000) >> 23; // exponent uint32_t m = b & 0x007FFFFF; // mantissa if (e != 0) { if (e < 110) { - } else if (e < 111) { - val |= 1; - if ((m >> 23) & 1) { - // rounding - val += 1; - } - } else if (e < 113) { // 127 - 15 + 1 + } else if (e < 113) { + // denormalized number auto d = 112 - e; - val |= 1 << (1 - d); - val |= m >> (22 + d); - if ((m >> (21 + d)) & 1) { + if (d < 2) { + val |= 1 << (1 - d); + val |= m >> (22 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (21 + d); + if ((m & mask) && + ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { // rounding val += 1; } @@ -461,8 +476,12 @@ struct Float8E5M2 { #endif }; -inline ORT_HOST_DEVICE bool operator==(const Float8E5M2& left, const Float8E5M2& right) { return left.val == right.val; } -inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2& left, const Float8E5M2& right) { return left.val != right.val; } +inline ORT_HOST_DEVICE bool operator==(const Float8E5M2& left, const Float8E5M2& right) { + return left.val == right.val; +} +inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2& left, const Float8E5M2& right) { + return left.val != right.val; +} inline ORT_HOST_DEVICE bool operator<(const Float8E5M2& left, const Float8E5M2& right) { return left.val < right.val; } // User defined suffixes to make it easier to declare @@ -473,9 +492,7 @@ inline Float8E5M2 operator"" _f8e5m2fn(unsigned long long int v) { return Float8E5M2(narrow(v), Float8E5M2::FromBits()); } -inline Float8E5M2 operator"" _f8e5m2fnp8(long double v) { - return Float8E5M2(static_cast(v), true); -} +inline Float8E5M2 operator"" _f8e5m2fnp8(long double v) { return Float8E5M2(static_cast(v), true); } #endif @@ -513,40 +530,42 @@ struct Float8E5M2FNUZ { uint32_t b; std::memcpy(&b, &v, sizeof(b)); - val = (b & 0x80000000) >> 24; // sign - if ((b & 0x7fc00000) == 0x7fc00000) { - val = 0x80; - } else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf + val = (b & 0x80000000) >> 24; // sign + if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf if (saturate) { val |= 0x7F; } else { val = 0x80; } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val = 0x80; } else { uint32_t e = (b & 0x7F800000) >> 23; // exponent uint32_t m = b & 0x007FFFFF; // mantissa if (e != 0) { if (e < 109) { - } else if (e < 110) { - val |= 1; - if ((m >> 23) & 1) { - // rounding - val += 1; - } - } else if (e < 112) { // 127 - 16 + 1 + } else if (e < 112) { + // denormalized number auto d = 111 - e; - val |= 1 << (1 - d); - val |= m >> (22 + d); - if ((m >> (21 + d)) & 1) { + if (d < 2) { + val |= 1 << (1 - d); + val |= m >> (22 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (21 + d); + if ((m & mask) && + ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { // rounding val += 1; } - } else if (e < 143) { // 127 + 15 + 1 + } else if (e < 143) { + // normalized number auto ex = e - 111; val |= ex << 2; val |= m >> 21; - if (m & 0x100000) { + if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) { if ((val & 0x7F) < 0x7F) { // rounding val += 1; @@ -554,7 +573,7 @@ struct Float8E5M2FNUZ { val = 0x80; } } - } else if ((e == 255) && (m == 0)) { // inf + } else if ((e == 255) && (m == 0)) { val = 0x80; } else if (saturate) { val |= 0x7F; @@ -605,9 +624,15 @@ struct Float8E5M2FNUZ { inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } }; -inline ORT_HOST_DEVICE bool operator==(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val == right.val; } -inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val != right.val; } -inline ORT_HOST_DEVICE bool operator<(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val < right.val; } +inline ORT_HOST_DEVICE bool operator==(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { + return left.val == right.val; +} +inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { + return left.val != right.val; +} +inline ORT_HOST_DEVICE bool operator<(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { + return left.val < right.val; +} // User defined suffixes to make it easier to declare // initializers with MLFloat8E5M2 and Float8E5M2 from unsigned char @@ -617,9 +642,7 @@ inline Float8E5M2FNUZ operator"" _f8e5m2fnuz(unsigned long long int v) { return Float8E5M2FNUZ(narrow(v), Float8E5M2FNUZ::FromBits()); } -inline Float8E5M2FNUZ operator"" _f8e5m2fnuzp8(long double v) { - return Float8E5M2FNUZ(static_cast(v), true); -} +inline Float8E5M2FNUZ operator"" _f8e5m2fnuzp8(long double v) { return Float8E5M2FNUZ(static_cast(v), true); } #endif diff --git a/include/onnxruntime/core/framework/ort_value.h b/include/onnxruntime/core/framework/ort_value.h index 48c4e4320d..a071f3182f 100644 --- a/include/onnxruntime/core/framework/ort_value.h +++ b/include/onnxruntime/core/framework/ort_value.h @@ -68,11 +68,7 @@ struct OrtValue { } bool IsSparseTensor() const { -#if !defined(DISABLE_SPARSE_TENSORS) return (type_ != nullptr && type_->IsSparseTensorType()); -#else - ORT_THROW("Sparse tensor is not supported in this build."); -#endif } onnxruntime::MLDataType Type() const { diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 7e59aad80c..957942749e 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -24,6 +24,7 @@ constexpr const char* kMSDmlDomain = "com.microsoft.dml"; constexpr const char* kNGraphDomain = "com.intel.ai"; constexpr const char* kMIGraphXDomain = ""; constexpr const char* kVitisAIDomain = "com.xilinx"; +constexpr const char* kQuadricDomain = "com.quadric"; // This is moved from the OrtApis::GetAvailableProviders implementation // where it is enforced diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bc7792ba43..456a11603d 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4333,8 +4333,12 @@ struct OrtApi { * \param[in] input_len Number of elements in the input_names and inputs arrays * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names * \param[in] output_names_len Number of elements in the output_names and outputs array - * \param[out] output Array of OrtValue* owned by customers, size to output_names_len. It could simply be an array of nullptr - * The array will be passed back to run_async_callback + * \param[out] output OrtValue* array of size output_names_len. + * On calling RunAsync, output[i] could either be a null or a pointer to a preallocated OrtValue. + * Later, the output array will be passed to run_async_callback with all null(s) filled with valid + * OrtValue pointer(s) allocated by onnxruntime. + * NOTE: it is customer's duty to finally release the output array and each of its member, + * regardless of whether the member (OrtValue*) is allocated by onnxruntime or preallocated by the customer. * \param[in] run_async_callback Callback function on model run completion * \param[in] user_data User data that pass back to run_async_callback */ diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index b9b6676c00..47356c3fe3 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1073,11 +1073,15 @@ struct SessionImpl : ConstSessionImpl { * * \param[in] run_options * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names - * \param[in] input_values Array of ::OrtValue%s of the input values + * \param[in] input_values Array of Value objects of length input_count * \param[in] input_count Number of elements in the input_names and inputs arrays * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names - * \param[out] output_values Array of ::OrtValue%s owned by customers, size to output_count. It could simply be an array of nullptr - * The array will be passed back to the callback + * \param[out] output_values Array of provided Values to be filled with outputs. + * On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*. + * Later, on invoking the callback, each output_values[i] of null will be filled with an OrtValue* allocated by onnxruntime. + * Then, an OrtValue** pointer will be casted from output_values, and pass to the callback. + * NOTE: it is customer's duty to finally release output_values and each of its member, + * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer. * \param[in] output_count Number of elements in the output_names and outputs array * \param[in] callback Callback function on model run completion * \param[in] user_data User data that pass back to the callback diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts index 8f597765eb..3e303bcf64 100644 --- a/js/common/lib/version.ts +++ b/js/common/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.16.0'; +export const version = '1.16.2'; diff --git a/js/common/package-lock.json b/js/common/package-lock.json index b9e5fd6082..69cb6b60aa 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-common", - "version": "1.16.0", + "version": "1.16.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-common", - "version": "1.16.0", + "version": "1.16.2", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" diff --git a/js/common/package.json b/js/common/package.json index 331f17dbc4..06616c3247 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -2,7 +2,7 @@ "license": "MIT", "type": "module", "name": "onnxruntime-common", - "version": "1.16.0", + "version": "1.16.2", "repository": { "url": "https://github.com/Microsoft/onnxruntime.git", "type": "git" diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts index 8f597765eb..3e303bcf64 100644 --- a/js/node/lib/version.ts +++ b/js/node/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.16.0'; +export const version = '1.16.2'; diff --git a/js/node/package-lock.json b/js/node/package-lock.json index bd01302262..6994f70a45 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-node", - "version": "1.16.0", + "version": "1.16.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-node", - "version": "1.16.0", + "version": "1.16.2", "license": "MIT", "os": [ "win32", @@ -27,7 +27,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.16.0", + "version": "1.16.2", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" diff --git a/js/node/package.json b/js/node/package.json index c898aeb56c..faa07d1149 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -13,7 +13,7 @@ 3 ] }, - "version": "1.16.0", + "version": "1.16.2", "dependencies": { "onnxruntime-common": "file:../common" }, diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index 78f32ec092..f8aeadbe27 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -68,7 +68,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) { int64_t bytesOffset = info[1].As().Int64Value(); int64_t bytesLength = info[2].As().Int64Value(); - ParseSessionOptions(info[1].As(), sessionOptions); + ParseSessionOptions(info[3].As(), sessionOptions); this->session_.reset( new Ort::Session(OrtEnv(), reinterpret_cast(buffer) + bytesOffset, bytesLength, sessionOptions)); } else { diff --git a/js/package-lock.json b/js/package-lock.json index be7b3c9cd7..15bea7b23b 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -2532,8 +2532,8 @@ } }, "node_modules/fastq": { - "version": "1.15.0", - "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", + "version": "1.15.1", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.1.tgz", "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==", "dev": true, "dependencies": { @@ -8095,8 +8095,8 @@ "dev": true }, "fastq": { - "version": "1.15.0", - "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", + "version": "1.15.1", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.1.tgz", "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==", "dev": true, "requires": { diff --git a/js/react_native/lib/backend.ts b/js/react_native/lib/backend.ts index b3f0c46630..058531f415 100644 --- a/js/react_native/lib/backend.ts +++ b/js/react_native/lib/backend.ts @@ -66,12 +66,14 @@ class OnnxruntimeSessionHandler implements SessionHandler { let results: Binding.ModelLoadInfoType; // load a model if (typeof this.#pathOrBuffer === 'string') { + // load model from model path results = await this.#inferenceSession.loadModel(normalizePath(this.#pathOrBuffer), options); } else { + // load model from buffer if (!this.#inferenceSession.loadModelFromBlob) { throw new Error('Native module method "loadModelFromBlob" is not defined'); } - const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer); + const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer.buffer); results = await this.#inferenceSession.loadModelFromBlob(modelBlob, options); } // resolve promise if onnxruntime session is successfully created diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts index 8f597765eb..3e303bcf64 100644 --- a/js/react_native/lib/version.ts +++ b/js/react_native/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.16.0'; +export const version = '1.16.2'; diff --git a/js/react_native/package.json b/js/react_native/package.json index 3020a04f0a..2c19037257 100644 --- a/js/react_native/package.json +++ b/js/react_native/package.json @@ -36,7 +36,7 @@ "registry": "https://registry.npmjs.org/" }, "source": "lib/index", - "version": "1.16.0", + "version": "1.16.2", "main": "dist/commonjs/index", "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", "files": [ diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index 21734bc50b..ff2cfd2c8f 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -5188,7 +5188,7 @@ onetime@^5.1.0, onetime@^5.1.2: mimic-fn "^2.1.0" "onnxruntime-common@file:../common": - version "1.16.0" + version "1.16.2" open@^6.2.0: version "6.4.0" diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 4a1109b9ec..e33854819c 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -38,7 +38,7 @@ Do not modify directly.* | Floor | ai.onnx(6-12,13+) | | | Gather | ai.onnx(1-10,11-12,13+) | | | Gelu | com.microsoft(1+) | | -| Gemm | ai.onnx(7-8,9-10,11+) | | +| Gemm | ai.onnx(7-8,9-10,11-12,13+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts index 8f597765eb..3e303bcf64 100644 --- a/js/web/lib/version.ts +++ b/js/web/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.16.0'; +export const version = '1.16.2'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 1d490aa902..82fe3d5b6a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -26,43 +26,41 @@ import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, elementsPerThread: readonly number[]): string => { + outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false): string => { const isChannelsLast = attributes.format === 'NHWC'; const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; const channelDim = isChannelsLast ? 3 : 1; const outputSize = ShapeUtil.size(outputShape); - const outChannels = outputShape[isChannelsLast ? 3 : 1]; - const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - const isVec4 = inChannels % 4 === 0 && outChannels % 4 === 0; const workPerThread = isVec4 ? 2 : 1; + const group = attributes.group; + const wShape = inputs[1].dims; + const inputChannelsPerGroup = wShape[0] / group; + const outputChannelsPerGroup = wShape[1]; - const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : elementsPerThread[0]; - - const declareInputs = [ - `@group(0) @binding(0) var Dy: array<${ - isVec4 && innerElementSize === 4 ? 'vec4' : 'f32'}>;`, - `@group(0) @binding(1) var W: array<${isVec4 ? 'vec4' : 'f32'}>;` - ]; let declareFunctions = ` fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? 'vec4' : 'f32'}) { result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); }`; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims); - const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims); - const output = outputVariable('result', inputs[0].dataType, outputShape); + const components = isVec4 ? 4 : 1; + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims, components); + const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims, components); + const inputVariables = [dy, w]; + if (hasBias) { + inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]], components)); + } + const output = outputVariable('result', inputs[0].dataType, outputShape, components); const codeSnippet4 = `{ - let batch: u32 = global_id.z / outShape[1]; - let r = global_id.z % outShape[1]; - let c = global_id.y * ${workPerThread}; - let d1: u32 = global_id.x * 4; + let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / outShape[1]; + let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % outShape[1]; + let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; + let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4; let dyCorner = vec2(i32(r), i32(c)) - vec2(pads); @@ -73,18 +71,21 @@ const createConvTranspose2DOpProgramShaderSource = dotProd[i] = vec4(0.0); } for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = f32(dyCorner.x + wR) / f32(strides.x); - let wRPerm: u32= filterDims[0] - 1 - wR; + var dyR = (f32(dyCorner.x) + f32(wR)) / f32(strides.x); + let wRPerm = filterDims[0] - 1 - wR; if (dyR < 0.0 || dyR >= f32(outBackprop[1]) || - fract(dyR) > 0.0) { + fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = f32(dyCorner.y + wC) / f32(strides.y); - let dyC2 = f32(dyCorner.y + 1 + wC) / f32(strides.y); - let wCPerm: u32 = filterDims[1] - 1 - wC; + let dyC = (f32(dyCorner.y) + f32(wC)) / f32(strides.y); + let dyC2 = (f32(dyCorner.y) + 1.0 + f32(wC)) / f32(strides.y); + let wCPerm = filterDims[1] - 1 - wC; + if (wCPerm < 0) { + continue; + } var bDyCVal = true; var bDyCVal2 = true; if (dyC < 0.0 || dyC >= f32(outBackprop[2]) || @@ -101,57 +102,53 @@ const createConvTranspose2DOpProgramShaderSource = if (bDyCVal && bDyCVal2) { let d2Length = outBackprop[3]; for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('d2', 'd1', 'wRPerm', 'wCPerm')}; - let wValue1 = ${w.get('d2', 'd1 + 1', 'wRPerm', 'wCPerm')}; - let wValue2 = ${w.get('d2', 'd1 + 2', 'wRPerm', 'wCPerm')}; - let wValue3 = ${w.get('d2', 'd1 + 3', 'wRPerm', 'wCPerm')}; + let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; + let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; + let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; + let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - var xValue = ${ - isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC')}; - let tmpval = vec4(xValue * wValue0, - xValue * wValue1, - xValue * wValue2, - xValue * wValue3); + var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; + let tmpval = vec4(dot(xValue, wValue0), + dot(xValue, wValue1), + dot(xValue, wValue2), + dot(xValue, wValue3)); dotProd[0] = dotProd[0] + tmpval; - xValue = ${ - isChannelsLast ? dy.get('batch', 'idyR', 'idyC2', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC2')}; + xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - dotProd[1] = dotProd[1] + vec4(xValue * wValue0, - xValue * wValue1, - xValue * wValue2, - xValue * wValue3); + dotProd[1] = dotProd[1] + vec4(dot(xValue, wValue0), + dot(xValue, wValue1), + dot(xValue, wValue2), + dot(xValue, wValue3)); } } else if (bDyCVal) { - let d2Length = outBackprop[3]; + let d2Length = outBackprop[${channelDim}]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('d2', 'd1', 'wRPerm', 'wCPerm')}; - let wValue1 = ${w.get('d2', 'd1 + 1', 'wRPerm', 'wCPerm')}; - let wValue2 = ${w.get('d2', 'd1 + 2', 'wRPerm', 'wCPerm')}; - let wValue3 = ${w.get('d2', 'd1 + 3', 'wRPerm', 'wCPerm')}; + let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; + let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; + let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; + let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - var xValue = ${ - isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC')}; - let tmpval = vec4(xValue * wValue0, - xValue * wValue1, - xValue * wValue2, - xValue * wValue3); + var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; + let tmpval = vec4(dot(xValue, wValue0), + dot(xValue, wValue1), + dot(xValue, wValue2), + dot(xValue, wValue3)); dotProd[0] = dotProd[0] + tmpval; } } else if (bDyCVal2) { let d2Length = outBackprop[3]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('d2', 'd1', 'wRPerm', 'wCPerm')}; - let wValue1 = ${w.get('d2', 'd1 + 1', 'wRPerm', 'wCPerm')}; - let wValue2 = ${w.get('d2', 'd1 + 2', 'wRPerm', 'wCPerm')}; - let wValue3 = ${w.get('d2', 'd1 + 3', 'wRPerm', 'wCPerm')}; + let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; + let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; + let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; + let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - var xValue = ${ - isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC')}; - let tmpval = vec4(xValue * wValue0, - xValue * wValue1, - xValue * wValue2, - xValue * wValue3); + var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; + let tmpval = vec4(dot(xValue, wValue0), + dot(xValue, wValue1), + dot(xValue, wValue2), + dot(xValue, wValue3)); dotProd[1] = dotProd[1] + tmpval; } } @@ -159,16 +156,21 @@ const createConvTranspose2DOpProgramShaderSource = } for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) { - ${output.set('batch', 'r', 'c+i', 'd1', 'dotProd[i]')}; + let value = dotProd[i] + ${hasBias ? 'bias[c+i]' : '0.0'}; + ${output.set('batch', 'r', 'c + i', 'd1', 'value')}; } }`; const codeSnippet = ` let outputIndices = ${output.offsetToIndices('global_idx')}; - let batch = outputIndices[0]; - let d1 = outputIndices[${channelDim}]; - let dyCorner = vec2(i32(outputIndices[${rowDim}]), i32(outputIndices[${colDim}])) - pads; + let batch = ${output.indicesGet('outputIndices', 0)}; + let d1 = ${output.indicesGet('outputIndices', channelDim)}; + let r = ${output.indicesGet('outputIndices', rowDim)}; + let c = ${output.indicesGet('outputIndices', colDim)}; + let dyCorner = vec2(i32(r), i32(c)) - pads; let dyRCorner = dyCorner.x; let dyCCorner = dyCorner.y; + let groupId = d1 / ${outputChannelsPerGroup}; + let wOutChannel = d1 - groupId * ${outputChannelsPerGroup}; // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. var dotProd = 0.0; @@ -178,7 +180,7 @@ const createConvTranspose2DOpProgramShaderSource = } let dyR = (f32(dyRCorner) + f32(wR)) / f32(strides[0]); let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= f32(outBackprop[1]) || fract(dyR) > 0.0 || + if (dyR < 0.0 || dyR >= f32(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } @@ -190,30 +192,29 @@ const createConvTranspose2DOpProgramShaderSource = } let dyC = (f32(dyCCorner) + f32(wC)) / f32(strides.y); let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= f32(outBackprop[2]) || + if (dyC < 0.0 || dyC >= f32(outBackprop[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } let idyC: u32 = u32(dyC); - for (var d2: u32 = 0; d2 < outBackprop[3]; d2 = d2 + 1) { + for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { + let inputChannel = groupId * ${inputChannelsPerGroup} + d2; let xValue = ${ - isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC')}; - let wValue = ${w.get('d2', 'd1', 'wRPerm', 'wCPerm')}; + isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : + dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; + let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')}; dotProd = dotProd + xValue * wValue; } } } - ${output.setByOffset('global_idx', 'dotProd')}; + let value = dotProd + ${hasBias ? 'bias[d1]' : '0.0'}; + ${output.setByOffset('global_idx', 'value')}; `; return ` - ${w.impl('indicesToOffset', 'get')} - ${dy.impl('indicesToOffset', 'get')} - ${output.impl('offsetToIndices')} + ${shaderHelper.declareVariables(...inputVariables, output)} ${declareFunctions} - ${declareInputs.join('\n')} - @group(0) @binding(${declareInputs.length}) var result: array<${isVec4 ? 'vec4' : 'f32'}>; const outShape : vec4 = vec4(${outputShape.join(',')}); const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); @@ -240,25 +241,18 @@ export const createConvTranspose2DProgramInfo = (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes, squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]): ProgramInfo => { const hasBias = inputs.length > 2; - const isChannelsLast = attributes.format === 'NHWC'; + // const isChannelsLast = attributes.format === 'NHWC'; const outputShape = attributes.outputShape; - const batchSize = outputShape[0]; - const outWidth = outputShape[isChannelsLast ? 1 : 2]; - const outHeight = outputShape[isChannelsLast ? 2 : 3]; - const outChannels = outputShape[isChannelsLast ? 3 : 1]; - const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - const isVec4 = inChannels % 4 === 0 && outChannels % 4 === 0; + const outputSize = ShapeUtil.size(outputShape); - const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; - const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = - isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; - const elementsPerThread = - isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1]; + // const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; + // TODO Enable isVec4 for performance + // Disabled due to weight matrix layout issue + // const isVec4 = attributes.group === 1 && isChannelsLast && inChannels % 4 === 0 && outChannels % 4 === 0; const dispatch = [ - Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), - Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), - Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[1]) + Math.ceil(outputSize / 64), + 1, + 1, ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); @@ -271,6 +265,6 @@ export const createConvTranspose2DProgramInfo = }], dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, elementsPerThread), + shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 5f3d156466..02b978a381 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -50,8 +50,6 @@ const createBinaryOpProgramShader = }; broadcastImpl = ` - ${output.impl('offsetToIndices')} - fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 { return ${calcOffsetImpl(dimsA)}; } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index e64c749725..7da57bcb9c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -16,28 +16,6 @@ import {ShapeUtil} from '../../util'; **/ export const WORKGROUP_SIZE = 64; -interface IndicesHelperImplementations { - /** - * implementation of `offsetToIndices` function. - */ - readonly offsetToIndices: string; - - /** - * implementation of `indicesToOffset` function. - */ - readonly indicesToOffset: string; - - /** - * implementation of `set`, `setByIndices` and `setByOffset` function. - */ - readonly set: string; - - /** - * implementation of `get`, `getByIndices` and `getByOffset` function. - */ - readonly get: string; -} - interface IndicesHelperTypes { /** * WGSL type of indices expression @@ -96,12 +74,10 @@ interface IndicesHelperTypes { */ export interface IndicesHelper { /** - * get WGSL code of function implementation for the util functions + * get WGSL code of function implementation for the util functions. * - * @param functions - a list of function names to get implementation for. If not specified, all functions will be - * returned. */ - readonly impl: (...functions: ReadonlyArray) => string; + readonly impl: () => string; /** * get type info @@ -215,9 +191,12 @@ export interface IndicesHelper { readonly shape: readonly number[]; } -const getWgslValueType = (type: number, components: 1|2|3|4): string|[string, string] => { +const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => { // return type is [ storage type, runtime type ] or a single string for both switch (type) { + // TODO: enable after "shader-f16" WSGL extension release + // case DataType.float16: + // return components > 1 ? `vec${components}` : 'f16'; case DataType.float: return components > 1 ? `vec${components}` : 'f32'; case DataType.int32: @@ -245,6 +224,11 @@ const getWgslValueType = (type: number, components: 1|2|3|4): string|[string, st } }; +export const tensorTypeToWsglStorageType = (type: DataType, components: 1|2|3|4 = 1) => { + const mappedType = getWgslMappedType(type, components); + return typeof mappedType === 'string' ? mappedType : mappedType[0]; +}; + /** * A helper function to get a IndicesHelper for a given input or output. * @@ -260,13 +244,22 @@ const createIndicesHelper = components: 1|2|3|4): IndicesHelper => { const rank = shape.length; const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}` : `array`; - const mappedType = getWgslValueType(tensorType, components); + const mappedType = getWgslMappedType(tensorType, components); const valueType = typeof mappedType === 'string' ? mappedType : mappedType[1]; const storageType = typeof mappedType === 'string' ? mappedType : mappedType[0]; const type = {indices: indicesType, value: valueType, storage: storageType, tensor: tensorType}; const normalizeDim = (dim: number|string): string => typeof dim === 'string' ? dim : `${dim}u`; + const implementationUsed = { + offsetToIndices: false, + indicesToOffset: false, + set: false, + setByIndices: false, + get: false, + getByIndices: false, + }; + const strides = ShapeUtil.computeStrides(shape); let o2iSnippet = ''; for (let i = 0; i < rank - 1; i++) { @@ -287,7 +280,10 @@ const createIndicesHelper = return indices; }`; - const offsetToIndices = (varOffset: string) => rank < 2 ? varOffset : `o2i_${name}(${varOffset})`; + const offsetToIndices = (varOffset: string) => { + implementationUsed.offsetToIndices = true; + return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`; + }; const offsets: string[] = []; if (rank >= 2) { @@ -301,7 +297,10 @@ const createIndicesHelper = return ${offsets.join('+')}; }`; - const indicesToOffset = (varIndices: string) => rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; + const indicesToOffset = (varIndices: string) => { + implementationUsed.indicesToOffset = true; + return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; + }; const indices = (...init: ReadonlyArray) => rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`; @@ -357,17 +356,18 @@ const createIndicesHelper = } })(); + const getByIndicesImplementation = rank < 2 ? '' : ` + fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} { + return ${name}[i2o_${name}(indices)]; + }`; + const getImplementation = rank < 2 ? '' : (() => { const params = shape.map((_, i) => `d${i}: u32`).join(', '); const dims = shape.map((_, i) => `d${i}`).join(', '); return ` - fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} { - return ${name}[i2o_${name}(indices)]; - } fn get_${name}(${params}) -> ${valueType} { return get_${name}ByIndices(${indices(dims)}); - } - `; + }`; })(); const get = (...indices: ReadonlyArray) => { @@ -376,14 +376,16 @@ const createIndicesHelper = } const normalizedIndices = indices.map(normalizeDim).join(','); - const funcName = `get_${name}`; if (rank === 0) { return getByOffset('0u'); } else if (rank === 1) { return getByOffset(normalizedIndices[0]); } else { - return `${funcName}(${normalizedIndices})`; + implementationUsed.get = true; + implementationUsed.getByIndices = true; + implementationUsed.indicesToOffset = true; + return `get_${name}(${normalizedIndices})`; } }; @@ -391,21 +393,24 @@ const createIndicesHelper = if (rank < 2) { return getByOffset(varIndices); } else { + implementationUsed.getByIndices = true; + implementationUsed.indicesToOffset = true; return `get_${name}ByIndices(${varIndices})`; } }; + const setByIndicesImplementation = rank < 2 ? '' : ` + fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) { + ${setByOffset(`i2o_${name}(indices)`, 'value')} + }`; + const setImplementation = rank < 2 ? '' : (() => { const params = shape.map((_, i) => `d${i}: u32`).join(', '); const dims = shape.map((_, i) => `d${i}`).join(', '); return ` - fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) { - ${setByOffset(`i2o_${name}(indices)`, 'value')} - } fn set_${name}(${params}, value: ${valueType}) { set_${name}ByIndices(${indices(dims)}, value); - } - `; + }`; })(); const set = (...indicesAndValue: ReadonlyArray) => { @@ -424,6 +429,9 @@ const createIndicesHelper = } else if (rank === 1) { return setByOffset(normalizedIndices[0], value); } else { + implementationUsed.set = true; + implementationUsed.setByIndices = true; + implementationUsed.indicesToOffset = true; return `set_${name}(${normalizedIndices}, ${value})`; } }; @@ -432,32 +440,34 @@ const createIndicesHelper = if (rank < 2) { return setByOffset(varIndices, value); } else { + implementationUsed.setByIndices = true; + implementationUsed.indicesToOffset = true; return `set_${name}ByIndices(${varIndices}, ${value});`; } }; - const funcImpls = { - offsetToIndices: offsetToIndicesImplementation, - indicesToOffset: indicesToOffsetImplementation, - set: setImplementation, - get: getImplementation, - }; - const impl = (...functions: Array) => { + const impl = () => { const impls = []; - if (functions.length === 0) { - functions.push('offsetToIndices', 'indicesToOffset', 'set', 'get'); + if (implementationUsed.offsetToIndices) { + impls.push(offsetToIndicesImplementation); + } + if (implementationUsed.indicesToOffset) { + impls.push(indicesToOffsetImplementation); + } + if (implementationUsed.set) { + impls.push(setImplementation); + } + if (implementationUsed.setByIndices) { + impls.push(setByIndicesImplementation); + } + if (implementationUsed.get) { + impls.push(getImplementation); } - for (const func of functions) { - const impl = funcImpls[func]; - if (impl === undefined) { - throw new Error(`unknown function ${func}`); - } else { - impls.push(impl); - } + if (implementationUsed.getByIndices) { + impls.push(getByIndicesImplementation); } return impls.join('\n'); }; - impl.toString = () => impl(); return { impl, @@ -552,6 +562,11 @@ export interface ShaderHelper { * @param variables - an array of IndicesHelper for the variables. */ declareVariables(...variables: IndicesHelper[]): string; + + /** + * Get additional implementation that needs to be added to the shader source. + */ + readonly additionalImplementations: string; } class ShaderHelperImpl implements ShaderHelper { @@ -585,6 +600,7 @@ class ShaderHelperImpl implements ShaderHelper { } declareVariable(variable: IndicesHelper, bindingIndex: number): string { + this.indicesHelpers.push(variable); const access = variable.usage === 'input' ? 'read' : 'read_write'; const storageType = variable.type.storage; return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; @@ -594,6 +610,12 @@ class ShaderHelperImpl implements ShaderHelper { let i = 0; return variables.filter(v => ShapeUtil.size(v.shape) > 0).map(v => this.declareVariable(v, i++)).join('\n'); } + + private indicesHelpers: IndicesHelper[] = []; + + get additionalImplementations(): string { + return this.indicesHelpers.map(i => i.impl()).join('\n'); + } } export const createShaderHelper = (dispatchGroup: [number, number, number]): ShaderHelper => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 8b91b64a09..9b294803d3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -109,9 +109,6 @@ const createConcatProgramInfo = const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.declareVariables(...inputVars, output)} - ${inputVars.map(i => i.impl('indicesToOffset', 'get')).join('\n')} - ${output.impl('offsetToIndices')} - const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); ${calculateInputIndexImpl(sizeInConcatAxis.length)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 7a0e1f01c4..8a794ce16a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -47,9 +47,6 @@ const createGroupedConvProgramInfo = ${shaderHelper.declareVariables(...inputVars, output)} ${activationFunction} - ${output.impl('offsetToIndices')} - ${x.impl('indicesToOffset', 'get')} - ${w.impl('indicesToOffset', 'get')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index b07fe3a90f..2d845775f1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -58,8 +58,6 @@ const createExpandProgramInfo = (metadata: ProgramMetadata, inputs: readonly Ten const getShaderSource = (shaderHelper: ShaderHelper) => ` const inputShape = ${input.indices(...inputShape)}; ${shaderHelper.declareVariables(input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset', 'get')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} let outputIndices = ${output.offsetToIndices('global_idx')}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 2ce8427bb6..f62c766aa9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType, tensorTypeToWsglType} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; -import {ShaderHelper} from './common'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; export interface InstanceNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -45,7 +45,7 @@ const createInstanceNormProgramInfo = Got scale size of ${scaleSize} and bias size of ${biasSize}`); } - const dataType = tensorTypeToWsglType(inputs[0].dataType); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const getShaderSource = (shaderHelper: ShaderHelper) => ` const C: u32 = ${C}; @@ -99,7 +99,7 @@ const createInstanceNormNHWCProgramInfo = const C = xShape[xShape.length - 1]; const H = ShapeUtil.sizeFromDimension(xShape, 1) / C; - const dataType = tensorTypeToWsglType(inputs[0].dataType); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const normCount = C * N; const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 48627bfaec..8a9927b25a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType, tensorTypeToWsglType} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; -import {ShaderHelper} from './common'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; export interface LayerNormAttributes extends AttributeWithCacheKey { axis: number; @@ -54,7 +54,7 @@ const createLayerNormProgramInfo = } } - const dataType = tensorTypeToWsglType(inputs[0].dataType); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const hasMeanDataOutput = outputCount > 1; const hasInvStdOutput = outputCount > 2; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 9af8fc7b6d..79071d3244 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -128,9 +128,6 @@ const generatePoolingCode = (${attributes.pads.map(i => `${i}u`).join(',')}); const inputDims = array(${inputDims.map(i => `${i}u`).join(',')}); const kernelStrides = array(${kernelStrides.map(i => `${i}u`).join(',')}); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index b645510d83..cb592c838d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -85,9 +85,6 @@ export const createReduceProgramInfo = const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.declareVariables(input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset')} - ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} var inputIndices: ${input.type.indices}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 505bae7ce2..1d0b8229a7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -484,8 +484,6 @@ const createResizeProgramInfo = } })()}; ${shaderHelper.declareVariables(input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} if (${noScale}) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 96bf1cd9a6..4b845bcf21 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType, tensorTypeToWsglType} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; -import {ShaderHelper} from './common'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; export interface SkipLayerNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -84,7 +84,7 @@ const createSkipLayerNormProgramInfo = const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : []; const hasBetaInput = inputs.length > 3; const hasBiasInput = inputs.length > 4; - const dataType = tensorTypeToWsglType(inputs[0].dataType); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const hasMeanOutput = isTraining && outputCount > 1; const hasInvStdDevOutput = isTraining && outputCount > 2; const hasInputSkipBiasSumOutput = outputCount > 3; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 1f881a75ff..4211e52689 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -153,8 +153,6 @@ const createSliceProgramInfo = const steps = array(${steps.map(i => `${i}u`).join(',')}); const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset', 'get')} ${calculateInputIndicesImpl(input, output, inputShape, outputShape)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 54f4934228..9a150d21ea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -23,10 +23,12 @@ const validateInputs = (inputs: readonly TensorView[]): void => { const createSplitAttributesFromInputs = (inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => { const splitSizes: number[] = []; + let numOutputs: number = attributes.numOutputs; if (inputs[1].dims[0] > 0) { inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v))); + numOutputs = splitSizes.length; } - return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes}); + return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes}); }; const calculateOutputIndexImpl = (numberOfTensors: number): string => ` @@ -85,8 +87,6 @@ const createSplitProgramInfo = const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`; const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.declareVariables(input, ...outputs)} - ${input.impl('indicesToOffset', 'offsetToIndices', 'get')} - ${outputs.map(o => o.impl('indicesToOffset', 'set')).join('\n')} const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); ${calculateOutputIndexImpl(sizeInConcatAxis.length)} ${writeBufferDataImpl(outputs)} @@ -114,7 +114,7 @@ const createSplitProgramInfoLoader = const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes); const metadata: ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey}; - return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)}; + return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)}; }; export const split = (context: ComputeContext, attributes: SplitAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index 2b80ce1732..99d9668757 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -66,8 +66,6 @@ export const createTileProgramInfo = const getShaderSource = (shaderHelper: ShaderHelper) => ` const inputShape = ${input.indices(...inputShape)}; ${shaderHelper.declareVariables(input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset', 'get')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} let outputIndices = ${output.offsetToIndices('global_idx')}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 0b0185fc17..ebedc61712 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -64,8 +64,6 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ${shaderHelper.declareVariables(input, output)} ${permFunctionBody(perm, rank, input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset', 'get')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index b46b35b714..da710b7dc2 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -114,7 +114,9 @@ export class ProgramManager { build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact { const device = this.backend.device; - const code = programInfo.getShaderSource(createShaderHelper(normalizedDispatchGroupSize)); + const shaderHelper = createShaderHelper(normalizedDispatchGroupSize); + const userCode = programInfo.getShaderSource(shaderHelper); + const code = `${shaderHelper.additionalImplementations}\n${userCode}`; const shaderModule = device.createShaderModule({code}); LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`); diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index a89a585906..389773f3e8 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -164,19 +164,3 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro throw new Error(`unsupported logging level: ${logLevel}`); } }; - -export const tensorTypeToWsglType = (type: DataType) => { - switch (type) { - case DataType.float: - return 'f32'; - // TODO: enable after "shader-f16" WSGL extension release - // case DataType.float16: - // return 'f16'; - case DataType.int32: - return 'i32'; - case DataType.uint32: - return 'u32'; - default: - throw new Error(`Unsupported type: ${type}`); - } -}; diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 4c5649d880..8ad55996f7 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.16.0", + "version": "1.16.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.16.0", + "version": "1.16.2", "license": "MIT", "dependencies": { "flatbuffers": "^1.12.0", @@ -49,7 +49,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.16.0", + "version": "1.16.2", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" diff --git a/js/web/package.json b/js/web/package.json index ce06475f67..76f793263e 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -8,7 +8,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.16.0", + "version": "1.16.2", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^1.12.0", diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc new file mode 100644 index 0000000000..a249dc807f --- /dev/null +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -0,0 +1,289 @@ +[ + { + "name": "ConvTranspose without bias addition A", + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [10, 40, 40, 60, 200, 160, 90, 240, 160], + "dims": [1, 1, 3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose without bias addition B", + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80], + "dims": [1, 2, 2, 2], + "type": "float32" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [2, 2, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 400, 940, 560, 1080, 2520, 1480, 760, 1740, 1000, 640, 1500, 880, 1720, 3960, 2280, 1160, 2620, 1480 + ], + "dims": [1, 2, 3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose with bias addition A", + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40], + "dims": [1, 4, 1, 1], + "type": "float32" + }, + { + "data": [ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + ], + "dims": [4, 4, 2, 2], + "type": "float32" + }, + { + "data": [0.1, 0.2, 0.3, 0.4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.19999694824219, + 100.19999694824219, 100.19999694824219, 100.19999694824219, 100.30000305175781, 100.30000305175781, + 100.30000305175781, 100.30000305175781, 100.4000015258789, 100.4000015258789, 100.4000015258789, + 100.4000015258789 + ], + "dims": [1, 4, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose with bias addition B", + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [6, 8, 7, 9, 15, 11, 8, 12, 9], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 1, 1, 1], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [5], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [11, 19, 20, 12, 20, 43, 46, 23, 22, 49, 52, 25, 13, 25, 26, 14], + "dims": [1, 1, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose- group - A", + "operator": "ConvTranspose", + "attributes": [ + { "name": "kernel_shape", "data": [1, 1], "type": "ints" }, + { "name": "group", "data": 2, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0], + "dims": [1, 2, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0], + "dims": [2, 1, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 36, 40, 44, 48, 52, 56, 60, 64, 68], + "dims": [1, 2, 3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose- group - B", + "operator": "ConvTranspose", + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + }, + { + "data": [0.125, 0.25, 0.375], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.125, 1.125, 4.125, 4.125, 3.125, 13.125, 23.125, 18.125, 15.125, 43.125, 53.125, 36.125, 18.125, 45.125, + 52.125, 32.125, 45.25, 104.25, 115.25, 66.25, 123.25, 279.25, 305.25, 172.25, 159.25, 357.25, 383.25, + 214.25, 105.25, 232.25, 247.25, 136.25, 162.375, 351.375, 370.375, 200.375, 387.375, 833.375, 875.375, + 470.375, 231.375, 494.375, 517.375, 276.375, 0.375, 0.375, 0.375, 0.375 + ], + "dims": [1, 3, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose- group - C", + "operator": "ConvTranspose", + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0 + ], + "dims": [1, 3, 3, 4], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0, 1, 4, 7, 6, 4, 16, 26, 36, 26, 20, 56, 66, 76, 50, 24, 59, 66, 73, 44, 60, 137, 148, 159, 90, 164, 368, + 394, 420, 234, 212, 472, 498, 524, 290, 140, 307, 322, 337, 184, 216, 465, 484, 503, 270, 516, 1104, 1146, + 1188, 634, 596, 1272, 1314, 1356, 722, 352, 747, 770, 793, 420 + ], + "dims": [1, 3, 4, 5], + "type": "float32" + } + ] + } + ] + }, + + { + "name": "ConvTranspose- pointwise", + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + "dims": [1, 2, 2, 2], + "type": "float32" + }, + { + "data": [0.0, 1.0, 2.0, 3.0], + "dims": [2, 2, 1, 1], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9, 11, 13, 15, 14, 18, 22, 26], + "dims": [1, 2, 2, 2], + "type": "float32" + } + ] + } + ] + } +] diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index d39d8edf0b..022451c885 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -7,7 +7,7 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_ or the `Github project `_. """ -__version__ = "1.16.0" +__version__ = "1.16.2" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index f1ab3e691b..b693b58c7c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -37,6 +37,7 @@ enum AttentionKernelType { AttentionKernel_TrtFlashAttention, AttentionKernel_TrtFusedCrossAttention, AttentionKernel_CutlassMemoryEfficientAttention, + AttentionKernel_FlashAttention, AttentionKernel_Default }; @@ -54,6 +55,7 @@ struct AttentionParameters { int v_hidden_size; // hidden size of V int v_head_size; // hidden size per head of V int num_heads; + int num_splits; bool is_unidirectional; bool past_present_share_buffer; bool do_rotary; @@ -81,6 +83,27 @@ struct PackedAttentionParameters { bool broadcast_res_pos_bias; }; +// Parameters deduced from node attributes and inputs/outputs. +struct GroupQueryAttentionParameters { + int batch_size; + int sequence_length; // sequence length of input query, key, value + int seqlen_past_kv_cache; // sequence length of past kv tensor + int seqlen_present_kv_cache; // sequence length of present kv tensor + int hidden_size; + int num_heads; + int head_size; + int kv_hidden_size; + int kv_num_heads; + int num_splits; // number of splits for splitkv + bool is_unidirectional; // causal + bool kv_share_buffer; + bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool left_padding; // copies last token to last index if true + float scale; + AttentionQkvFormat qkv_format; + AttentionQkvFormat past_kv_format; +}; + namespace attention { // Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled). constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION"; @@ -98,8 +121,16 @@ constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTI // Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled). constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION"; +// Environment variable to enable or disable flash attention. Default is 0 (enabled). +constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION"; + // Minimum sequence length to enable memory efficient attention in FP32. -constexpr int kMinSequenceLengthForMemoryEfficientAttentionFp32 = 256; +constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256; + +// Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention +constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"; +// Default value for the above setting. +constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513; } // namespace attention diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 0b55cb7804..694c40bf3e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -16,7 +16,6 @@ #include #include -#include using onnxruntime::concurrency::ThreadPool; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 1dc85e6d34..8eae3599a6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -205,6 +205,7 @@ Status CheckInputs(const T* query, } } + int total_sequence_length = past_sequence_length + kv_sequence_length; AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; if (key_padding_mask != nullptr) { mask_type = AttentionMaskType::MASK_UNKNOWN; @@ -215,13 +216,21 @@ Status CheckInputs(const T* query, } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; } - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && mask_dims[1] == static_cast(kv_sequence_length)) { + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(kv_sequence_length)) { + mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(total_sequence_length)) { mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(sequence_length) && + mask_dims[2] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_3D_ATTENTION; } if (mask_type == AttentionMaskType::MASK_UNKNOWN) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key_padding_mask' shape shall be (batch_size) or (batch_size, kv_sequence_length)"); + "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D"); } } @@ -256,7 +265,6 @@ Status CheckInputs(const T* query, } } - int total_sequence_length = past_sequence_length + kv_sequence_length; bool broadcast_res_pos_bias = false; if (relative_position_bias != nullptr) { const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc new file mode 100644 index 0000000000..4a266af789 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/rotary_embedding.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" + +#include "core/platform/threadpool.h" + +using onnxruntime::concurrency::ThreadPool; +using namespace onnxruntime::contrib::rotary_embedding_helper; + +namespace onnxruntime { +namespace contrib { + +// These ops are internal-only, so register outside of onnx +ONNX_OPERATOR_TYPED_KERNEL_EX( + RotaryEmbedding, + kMSDomain, + 1, + float, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +template +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { + scale = info.GetAttrOrDefault("scale", 1.0); + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); +} + +template +Status RotaryEmbedding::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* position_ids = context->Input(1); + const Tensor* cos_cache = context->Input(2); + const Tensor* sin_cache = context->Input(3); + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input, + position_ids, + cos_cache, + sin_cache, + ¶meters)); + + Tensor* output = context->Output(0, input->Shape()); + + if (parameters.sequence_length > parameters.max_sequence_length) { + // Launch update_cos_sin_cache kernel with scale + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + const T* input_src = input->Data(); + const int64_t* pos_ids_data = position_ids->Data(); + const T* cos_cache_data = cos_cache->Data(); + const T* sin_cache_data = sin_cache->Data(); + T* output_dest = output->MutableData(); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int head_size = parameters.head_size; + const int position_ids_format = parameters.position_ids_format; + const int half_head_size = head_size / 2; + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + auto* tp = context->GetOperatorThreadPool(); + + const int loop_len = batch_size * sequence_length * num_heads; + const double cost = static_cast(head_size); + ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { + const int b = static_cast((ptr / num_heads) / sequence_length); + const int s = static_cast((ptr / num_heads) % sequence_length); + const int n = static_cast(ptr % num_heads); + + const int block_offset = b * sequence_length * num_heads + s * num_heads + n; + const int data_offset = block_offset * head_size; + + const T* input_data = input_src + data_offset; + T* output_data = output_dest + data_offset; + + // Cache is (M, H/2) + const int position_id = (position_ids_format == 0) + ? static_cast(pos_ids_data[0]) + s + : static_cast(pos_ids_data[b * sequence_length + s]); + const int cache_offset = position_id * half_head_size; + const T* cos_data = cos_cache_data + cache_offset; + const T* sin_data = sin_cache_data + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + for (int i = 0; i < head_size; i++) { + if (interleaved) { + cache_idx = (i / 2) % half_head_size; + sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); + j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign + } else { + cache_idx = i % half_head_size; + sign = (i < half_head_size) ? static_cast(-1) : static_cast(1); + j = (i + half_head_size) % head_size; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + } + } + }); + + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h new file mode 100644 index 0000000000..be834a66cd --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +template +class RotaryEmbedding final : public OpKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + protected: + float scale; + bool interleaved; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h new file mode 100644 index 0000000000..cf8080800e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace contrib { +namespace rotary_embedding_helper { + +// Parameters deduced from node attributes and inputs/outputs. +struct RotaryParameters { + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size used by cos/sin cache * 2 + int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) +}; + +template +Status CheckInputs(const T* input, + const T* position_ids, + const T* cos_cache, + const T* sin_cache, + void* parameters) { + // input : (batch_size, sequence_length, hidden_size) + // position ids : (1) or (batch_size, sequence_length) + // cos cache : (max_sequence_length, head_size / 2) + // sin cache : (max_sequence_length, head_size / 2) + + // Check input + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ", + input_dims.size()); + } + // Check position_ids + const auto& position_ids_dims = position_ids->Shape().GetDims(); + if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ", + "dimensions, got ", position_ids_dims.size()); + } + // Check cos_cache and sin_cache + const auto& cos_cache_dims = cos_cache->Shape().GetDims(); + if (cos_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ", + cos_cache_dims.size()); + } + const auto& sin_cache_dims = sin_cache->Shape().GetDims(); + if (sin_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ", + sin_cache_dims.size()); + } + if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ", + "the same shape"); + } + + // Get attributes from inputs + int batch_size = static_cast(input_dims[0]); + int sequence_length = static_cast(input_dims[1]); + int hidden_size = static_cast(input_dims[2]); + int max_sequence_length = static_cast(cos_cache_dims[0]); + int head_size = static_cast(cos_cache_dims[1]) * 2; + int num_heads = hidden_size / head_size; + int position_ids_format = -1; + + // Check position_ids input shapes + if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) { + if (batch_size != static_cast(position_ids_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ", + "batch_size, got ", position_ids_dims[0]); + } + if (sequence_length != static_cast(position_ids_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ", + "sequence_length, got ", position_ids_dims[1]); + } + position_ids_format = 1; + } else { + position_ids_format = 0; + } + // Check cos_cache input shapes + if (max_sequence_length != static_cast(cos_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", + "max_sequence_length, got ", cos_cache_dims[0]); + } + if ((head_size / 2) != static_cast(cos_cache_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", + "head_size / 2, got ", cos_cache_dims[1]); + } + // Check sin_cache input shapes + if (max_sequence_length != static_cast(sin_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ", + "max_sequence_length, got ", sin_cache_dims[0]); + } + if ((head_size / 2) != static_cast(sin_cache_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ", + "head_size / 2, got ", sin_cache_dims[1]); + } + + // Set rotary parameters + if (parameters != nullptr) { + RotaryParameters* output_parameters = reinterpret_cast(parameters); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; + output_parameters->hidden_size = hidden_size; + output_parameters->head_size = head_size; + output_parameters->num_heads = num_heads; + output_parameters->max_sequence_length = max_sequence_length; + output_parameters->position_ids_format = position_ids_format; + } + + return Status::OK(); +} + +} // namespace rotary_embedding_helper +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index d6d844e245..49fdf2b332 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -19,6 +19,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); @@ -27,6 +28,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordC class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4); #ifndef ORT_MINIMAL_BUILD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4); #endif @@ -51,6 +54,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Quick // ******** Start: Quantization ******************* // class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConvTranspose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConcat); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearWhere); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool); @@ -117,6 +121,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu); @@ -185,6 +191,7 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -240,6 +247,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -252,6 +260,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifndef ORT_MINIMAL_BUILD BuildKernelCreateInfo, #endif @@ -285,6 +295,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/quadric/quadric_custom_op.cc b/onnxruntime/contrib_ops/cpu/quadric/quadric_custom_op.cc new file mode 100644 index 0000000000..e333205d37 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quadric/quadric_custom_op.cc @@ -0,0 +1,283 @@ +// Copyright (c) Quadric, Inc. All rights reserved. +// Licensed under the MIT License. + +#include "quadric_custom_op.h" +#include "core/common/common.h" +#include "core/framework/op_kernel_context_internal.h" +#include "core/framework/ortdevice.h" +#include "core/framework/ortmemoryinfo.h" +#include "core/framework/session_options.h" +#include "core/framework/session_state.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { + +ONNX_OPERATOR_KERNEL_EX(QuadricCustomOp, kQuadricDomain, 1, kCpuExecutionProvider, KernelDefBuilder(), QuadricCustomOp); + +QuadricCustomOp::Info::Info(const onnxruntime::Node& node, const GraphViewer& subgraph_in) : subgraph(subgraph_in), used_inputs(node.InputDefs().size(), false) { + num_inputs = static_cast(node.InputDefs().size()); + num_outputs = static_cast(node.OutputDefs().size()); + + auto& subgraph_inputs = subgraph.GetInputs(); + auto num_subgraph_inputs = subgraph_inputs.size(); + + for (size_t i = 0; i < num_subgraph_inputs; ++i) { + auto& input = subgraph_inputs[i]; + subgraph_input_names.insert(input->Name()); + } + + // This is only an inequality because we include initializers as inputs to the custom op, but + // *NOT* the sub-graph. As a result, the number of inputs differs. Unfortunately, ORT doesn't do + // a great job of telling us whether something is truly an initializer or not, so we can't + // effectively check whether an input is an initializer or not. + ORT_ENFORCE(num_subgraph_inputs <= static_cast(num_inputs), + "'QuadricCustomOp' node (", node.Name(), ") has ", num_inputs, " inputs which is fewer than the subgraph's ", + num_subgraph_inputs, " inputs."); + + auto& subgraph_outputs = subgraph.GetOutputs(); + auto num_subgraph_outputs = subgraph_outputs.size(); + + // outputs should always match up, so enforce that. + ORT_ENFORCE(num_subgraph_outputs == static_cast(num_outputs), + "'QuadricCustomOp' node has ", num_outputs, " outputs which doesn't match the subgraph's ", + num_subgraph_outputs, " outputs."); + + subgraph_output_names.reserve(num_subgraph_outputs); + for (size_t i = 0; i < num_subgraph_outputs; ++i) { + auto& output = subgraph_outputs[i]; + subgraph_output_names.push_back(output->Name()); + } +} + +class QuadricCustomOpImpl { + public: + QuadricCustomOpImpl(OpKernelContextInternal& context, + const SessionState& session_state, + const QuadricCustomOp::Info& info); + + Status Initialize(); + Status Execute(const FeedsFetchesManager& ffm); + + private: + OpKernelContextInternal& context_; + const SessionState& session_state_; + const QuadricCustomOp::Info& info_; + + Status AllocateOutputTensors(); + + enum class AllocationType { + Delayed, // allocation of If output will be done by subgraph execution + SubgraphOutput + }; + + // track where the fetches provided to subgraph execution were allocated. + std::vector> outputs_; +}; + +QuadricCustomOpImpl::QuadricCustomOpImpl(OpKernelContextInternal& context, + const SessionState& session_state, + const QuadricCustomOp::Info& info) : context_(context), + session_state_(session_state), + info_(info) {} + +Status QuadricCustomOpImpl::Initialize() { + auto status = AllocateOutputTensors(); + ORT_RETURN_IF_ERROR(status); + + return Status::OK(); +} + +Status QuadricCustomOpImpl::AllocateOutputTensors() { + // This function mostly copied from if.cc + Status status = Status::OK(); + int index = 0; + + const GraphViewer& subgraph = session_state_.GetGraphViewer(); + + const auto& graph_outputs = subgraph.GetOutputs(); + + for (auto& graph_output : graph_outputs) { + const auto* graph_output_type = graph_output->TypeAsProto(); + + ORT_ENFORCE(graph_output_type->has_tensor_type() || graph_output_type->has_sequence_type(), "Only tensors or tensor sequences are supported"); + if (graph_output_type->has_tensor_type()) { + auto* graph_output_shape = graph_output->Shape(); + bool symbolic_dim_in_shape = false; + + if (graph_output_shape) { + TensorShape output_shape = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*graph_output_shape); + + // if size < 0 we have a symbolic dimension and need to use a temporary OrtValue in the subgraph execution + if (output_shape.Size() < 0) { + symbolic_dim_in_shape = true; + } else { + auto* tensor = context_.Output(index, output_shape); + + if (!tensor) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for ", graph_output->Name()); + + outputs_.push_back({AllocationType::SubgraphOutput, *context_.GetOutputMLValue(index)}); + } + } + + if (!graph_output_shape || symbolic_dim_in_shape) { + // we still need a value to put in the feeds we give to the execution frame, so just use an empty MLValue + outputs_.push_back({AllocationType::Delayed, {}}); + } + } else if (graph_output_type->has_sequence_type()) { + auto* seq_tensor = context_.Output(index); + if (!seq_tensor) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for ", graph_output->Name()); + outputs_.push_back({AllocationType::SubgraphOutput, *context_.GetOutputMLValue(index)}); + } + ++index; + } + + return Status::OK(); +} + +Status QuadricCustomOpImpl::Execute(const FeedsFetchesManager& ffm) { + Status status = Status::OK(); + + auto num_inputs = context_.InputCount(); + std::vector feeds; + feeds.reserve(num_inputs); + + // This will contain used inputs, so some/all initializers may not be present + for (int i = 0; i < num_inputs; ++i) { + if(info_.used_inputs[i]) { + feeds.push_back(*context_.GetInputMLValue(i)); + } + } + + std::vector fetches; + std::unordered_map fetch_allocators; + + fetches.reserve(info_.num_outputs); + for (int i = 0; i < info_.num_outputs; ++i) { + fetches.push_back(outputs_[i].second); + + if (outputs_[i].first == AllocationType::Delayed) { + // functor to forward the allocation request from the subgraph to the If node's context so that the + // allocation plan for the If node's output is used. + fetch_allocators[i] = [this, i, &fetches](const TensorShape& shape, const OrtDevice& location, + OrtValue& ort_value, bool& allocated) { + // if the device the QuadricCustomOp output is allocated on does not match the required device for the subgraph output + // we don't update the provided OrtValue and return false for 'allocated'. + // the execution frame will allocate a buffer on the required device, and the fetches copy + // logic in utils::ExecuteSubgraph will handle moving it into the tensor we allocated here. + + auto* tensor = context_.Output(i, shape); + if (!tensor) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor for QuadricCustomOp output ", i); + + const OrtValue& value = *context_.GetOutputMLValue(i); + + if (tensor->Location().device == location) { + // return OrtValue for allocated tensor + ort_value = value; + allocated = true; + } else { + // put the allocated value into fetches so the copy logic in utils::ExecuteGraphImpl can use it + fetches[i] = value; + } + + return Status::OK(); + }; + } + } + + status = utils::ExecuteSubgraph(session_state_, ffm, feeds, fetches, fetch_allocators, + ExecutionMode::ORT_SEQUENTIAL, context_.GetTerminateFlag(), + context_.Logger(), context_.GetComputeStream()); + + ORT_RETURN_IF_ERROR(status); + + return status; +} + +QuadricCustomOp::QuadricCustomOp(const OpKernelInfo& info) : IControlFlowKernel(info) { + ONNX_NAMESPACE::GraphProto proto; + ORT_ENFORCE(info.GetAttr("sub_graph", &proto).IsOK()); + ORT_IGNORE_RETURN_VALUE(proto); +} + +Status QuadricCustomOp::Compute(OpKernelContext* ctx) const { + auto ctx_internal = static_cast(ctx); + auto* session_state = ctx_internal->SubgraphSessionState("sub_graph"); + ORT_ENFORCE(session_state, "Subgraph SessionState was not found for sub_graph attribute."); + + QuadricCustomOpImpl impl{*ctx_internal, *session_state, *info_}; + auto status = impl.Initialize(); + ORT_RETURN_IF_ERROR(status); + + status = impl.Execute(*feeds_fetches_manager_); + + return Status::OK(); +} + +Status QuadricCustomOp::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) { + const auto& node = Node(); + info_ = std::make_unique(node, subgraph_session_state.GetGraphViewer()); + + const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap(); + auto num_subgraph_inputs = subgraph_session_state.GetGraphViewer().GetInputs().size(); + + std::vector feed_names; + + const auto& input_defs = node.InputDefs(); + for (size_t i = 0, end = num_subgraph_inputs; i < end; ++i) { + const auto* input = input_defs[i]; + // Not all subgraph inputs will have names that correspond to the node's inputs. The inputs + // that diverge like this are limited *only* to initializers and we don't need to create + // feeds for them. Furthermore, since they are not actually used by the custom op (and + // not even by the sub-graph since the subgraph contains its own version of initializers) + // they end up getting removed from the graph during an optimization step and so we can't + // prove that it's an initializer using Graph::IsInitializedTensor + std::string input_name = input->Name(); + // Strip-off any '/duplicated' + auto pos = input_name.find("/duplicated"); + if (pos != std::string::npos) { + input_name = input_name.erase(pos); + } + ORT_ENFORCE(info_->subgraph_input_names.find(input_name) != info_->subgraph_input_names.end(), + "Could not match input ", input_name, " with any subgraph input."); + feed_names.push_back(input_name); + info_->used_inputs[i] = true; + } + + std::unique_ptr ffm; + ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, info_->subgraph_output_names, subgraph_map, ffm)); + ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm)); + + // find the location all the feeds will be coming from + std::vector feed_locations; + feed_locations.resize(feed_names.size()); + for (size_t i = 0, end = feed_names.size(); i < end; ++i) { + const auto& location = utils::FindDeviceForValue(session_state, feed_names[i]); + feed_locations[i] = location; + } + + std::vector fetch_locations; + fetch_locations.reserve(info_->num_outputs); + + // we need the allocator info for each output from the QuadricCustomOp node + // as the subgraph execution will write directly into those buffers + const auto& outputs = node.OutputDefs(); + for (int i = 0, end = info_->num_outputs; i < end; ++i) { + const auto& alloc_info = utils::FindDeviceForValue(session_state, outputs[i]->Name()); + fetch_locations.push_back(&alloc_info); + } + + utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations); + + feeds_fetches_manager_ = std::move(ffm); + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quadric/quadric_custom_op.h b/onnxruntime/contrib_ops/cpu/quadric/quadric_custom_op.h new file mode 100644 index 0000000000..4e840f81b6 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quadric/quadric_custom_op.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cpu/controlflow/utils.h" + +#include "core/framework/feeds_fetches_manager.h" +#include "core/framework/op_kernel.h" +#include "core/framework/op_kernel_context_internal.h" +#include "core/session/onnxruntime_cxx_api.h" +#include +#include + +namespace onnxruntime { + +struct QuadricCustomOp : public controlflow::IControlFlowKernel { + QuadricCustomOp(const OpKernelInfo& info); + + Status Compute(OpKernelContext* ctx) const override; + + virtual Status SetupSubgraphExecutionInfo(const SessionState& session_state, + const std::string& attribute_name, + const SessionState& subgraph_session_state) override; + + struct Info { + Info(const onnxruntime::Node& node, const GraphViewer& subgraph_in); + const GraphViewer& subgraph; + + int num_inputs; + int num_outputs; + + std::unordered_set subgraph_input_names; + std::vector used_inputs; + std::vector subgraph_output_names; + }; + + private: + std::unique_ptr info_; + std::unique_ptr feeds_fetches_manager_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h new file mode 100644 index 0000000000..11b5447d65 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block.h @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +#if defined(_MSC_VER) +#define FORCEINLINE __forceinline +#else +#define FORCEINLINE __attribute__((always_inline)) inline +#endif + +template +struct alignas(1) BlockwiseQuantBlock { + static_assert(block_size % 8 == 0); + + uint8_t blob_data[block_size / 8 * bits]; + + FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const; + FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const; + + FORCEINLINE void quant(const T* src, T& scale, int32_t k_idx, int32_t K, int32_t N); + FORCEINLINE void quant(const T* src, T& scale, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N); +}; + +template +struct alignas(1) BlockwiseQuantBlock { + static_assert(block_size % 8 == 0); + + uint8_t blob_data[block_size / 2]; + + FORCEINLINE void dequant(T* dst, T scale, uint8_t zp, int32_t k_idx, int32_t K) const { + for (int i = 0; i < block_size; i += 2) { + T zp_t = static_cast(float(zp)); + if (k_idx + i < K) { + T x0 = static_cast(float(blob_data[i / 2] & 0xF)); + dst[i] = scale * (x0 - zp_t); + } + if (k_idx + i + 1 < K) { + T x1 = static_cast(float(blob_data[i / 2] >> 4)); + dst[i + 1] = scale * (x1 - zp_t); + } + } + } + + FORCEINLINE void dequant(T* dst, T scale, int32_t k_idx, int32_t K) const { + constexpr uint8_t zp = 8; + dequant(dst, scale, zp, k_idx, K); + } + + FORCEINLINE void quant(const T* src, T& scale_block, uint8_t& zp, int32_t k_idx, int32_t K, int32_t N) { + float min = static_cast(*src); + float max = static_cast(*src); + int32_t klen = std::min(block_size, K - k_idx); + for (int32_t kk = 0; kk < klen; kk++) { + const float v = static_cast(src[N * kk]); + if (v < min) min = v; + if (v > max) max = v; + } + min = std::min(min, 0.0f); + max = std::max(max, 0.0f); + + const float scale = (max - min) / ((1 << 4) - 1); + scale_block = static_cast(scale); + + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + float zero_point_fp = min; + if (scale != 0.0f) { + zero_point_fp = 0.f - min / scale; + } + + // Handle any clamping + if (zero_point_fp < 0.0f) { + zp = 0; + } else if (zero_point_fp > 15.0f) { + zp = 15; + } else { + zp = (uint8_t)roundf(zero_point_fp); + } + + for (int32_t kk = 0; kk < klen; kk += 2) { + const float v0 = static_cast(src[N * kk]); + const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 * reciprocal_scale + zp))); + + const float v1 = static_cast((kk + 1 < klen) ? src[N * (kk + 1)] : 0.f); + const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 * reciprocal_scale + zp))); + + blob_data[kk / 2] = vi0 | (vi1 << 4); + } + } + + FORCEINLINE void quant(const T* src, T& scale_block, int32_t k_idx, int32_t K, int32_t N) { + float amax = 0.0f; // abs(max) + float max = 0.0f; + + int32_t klen = std::min(block_size, K - k_idx); + + for (int32_t kk = 0; kk < klen; kk++) { + const float v = static_cast(src[N * kk]); + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float scale = max / (-8.f); + scale_block = static_cast(scale); + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + + for (int32_t kk = 0; kk < klen; kk += 2) { + const float v0 = src[N * kk] * reciprocal_scale; + const uint8_t vi0 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v0 + 8.f))); + + const float v1 = (kk + 1 < klen) ? src[N * (kk + 1)] * reciprocal_scale : 0; + const uint8_t vi1 = (uint8_t)std::min(15.0f, std::max(0.0f, roundf(v1 + 8.f))); + + blob_data[kk / 2] = vi0 | (vi1 << 4); + } + } +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h new file mode 100644 index 0000000000..cb8e97a592 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +#if defined(_MSC_VER) +#define FORCEINLINE __forceinline +#else +#define FORCEINLINE __attribute__((always_inline)) inline +#endif + +typedef enum Bnb_DataType_t { + FP4 = 0, + NF4 = 1, +} Bnb_DataType_t; + +FORCEINLINE uint8_t QuantizeOneFP4(float x) { + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assum input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to noice if you add an extra + // zero somewhere! + + uint8_t sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if (x > 0.29166667f) { + if (x > 0.583333f) { + if (x > 0.8333333f) { + return 0b0011 + sign; + } else { + return 0b0010 + sign; + } + } else if (x > 0.4166667f) { + return 0b101 + sign; + } else { + return 0b100 + sign; + } + } else if (x > 0.0859375f) { + if (x > 0.20833333f) { + return 0b0111 + sign; + } else { + return 0b0110 + sign; + } + } else if (x > 0.00260417f) { + return 0b0001 + sign; + } else { + return 0b0000 + sign; + } +} + +FORCEINLINE uint8_t QuantizeOneNF4(float x) { + if (x > 0.03979014977812767f) { + if (x > 0.3893125355243683f) { // 1 + if (x > 0.6427869200706482f) { // 11 + if (x > 0.8614784181118011f) { // 111 + return 0b1111; + } else { + return 0b1110; + } + } else if (x > 0.5016634166240692f) { // 110 + return 0b1101; + } else { + return 0b1100; + } + } else if (x > 0.2035212516784668f) { // 10 + if (x > 0.2920137718319893f) { // 101 + return 0b1011; + } else { + return 0b1010; + } + } else if (x > 0.1202552504837513f) { // 100 + return 0b1001; + } else { + return 0b1000; + } + } else if (x > -0.33967943489551544f) { // 0 + if (x > -0.13791173323988914f) { // 01 + if (x > -0.045525018125772476f) { // 011 + return 0b0111; + } else { + return 0b0110; + } + } else if (x > -0.23460740596055984f) { // 010 + return 0b0101; + } else { + return 0b0100; + } + } else if (x > -0.6106329262256622f) { // 00 + if (x > -0.4599952697753906f) { // 001 + return 0b0011; + } else { + return 0b0010; + } + } else if (x > -0.8480964004993439f) { // 000 + return 0b0001; + } else { + return 0b0000; + } +} + +template +FORCEINLINE uint8_t QuantizeOneBnb4(float x) { + if constexpr (DATA_TYPE == FP4) + return QuantizeOneFP4(x); + else + return QuantizeOneNF4(x); +} + +template +FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) { + float local_absmax = 0.0f; + + int32_t block_len = std::min(block_size, numel - block_idx * block_size); + int32_t src_offset = block_idx * block_size; + int32_t dst_offset = block_idx * block_size / 2; + + for (int32_t idx = 0; idx < block_len; idx++) { + const float v = static_cast(src[src_offset + idx]); + local_absmax = fmaxf(local_absmax, fabsf(v)); + } + + absmax_block = static_cast(local_absmax); + const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f; + + for (int32_t idx = 0; idx < block_len; idx += 2) { + const float v0 = static_cast(src[src_offset + idx]) * reciprocal_absmax; + const uint8_t vi0 = QuantizeOneBnb4(v0); + + const float v1 = (idx + 1 < block_len) ? static_cast(src[src_offset + idx + 1]) * reciprocal_absmax : 0; + const uint8_t vi1 = QuantizeOneBnb4(v1); + + dst[dst_offset + idx / 2] = (vi0 << 4) | vi1; + } +} + +static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f, + 0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f, + -0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f, + -0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f}; + +static float nf4_qaunt_map[16] = {-1.0f, + -0.6961928009986877f, + -0.5250730514526367f, + -0.39491748809814453f, + -0.28444138169288635f, + -0.18477343022823334f, + -0.09105003625154495f, + 0.0f, + 0.07958029955625534f, + 0.16093020141124725f, + 0.24611230194568634f, + 0.33791524171829224f, + 0.44070982933044434f, + 0.5626170039176941f, + 0.7229568362236023f, + 1.0f}; + +template +FORCEINLINE T DequantizeOneBnb4(uint8_t x) { + if constexpr (DATA_TYPE == FP4) + return static_cast(fp4_qaunt_map[x]); + else + return static_cast(nf4_qaunt_map[x]); +} + +template +FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) { + int32_t block_len = std::min(block_size, numel - block_idx * block_size); + int32_t src_offset = block_idx * block_size / 2; + int32_t dst_offset = block_idx * block_size; + + for (int32_t idx = 0; idx < block_len; idx += 2) { + const uint8_t val = src[src_offset + idx / 2]; + + dst[dst_offset + idx] = DequantizeOneBnb4(val >> 4) * absmax_block; + if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4(val & 0xF) * absmax_block; + } +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h new file mode 100644 index 0000000000..8811e5649f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise.h @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "blockwise_quant_block.h" + +#include + +#include "core/common/safeint.h" +#include "core/framework/float16.h" +#include "core/platform/threadpool.h" +#include + +namespace onnxruntime { +namespace contrib { + +template +void QuantizeBlockwise( + uint8_t* dst, // shape: [ N, block_per_K, block_blob_size ] + const T* src, // shape: [K, N] + T* scale, // shape: [N * block_per_K] + uint8_t* zero_points, // shape: [N * block_per_K] if bits > 4 else [(N *block_per_K + 1) / 2] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + BlockwiseQuantBlock* dst_blob = + reinterpret_cast*>(dst); + + int32_t block_per_K = (K + block_size - 1) / block_size; + int32_t total_block_count = N * block_per_K; + + std::vector zero_points_tmp; // to avoid race condition + (void)zero_points_tmp; + uint8_t* zero_points_tmp_ptr = zero_points; + if (bits <= 4 && zero_points != nullptr) { + zero_points_tmp.resize(total_block_count, 0); + zero_points_tmp_ptr = zero_points_tmp.data(); + } + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + int32_t n = static_cast(block_idx / block_per_K); + int32_t k_block_idx = static_cast(block_idx % block_per_K); + int32_t k = k_block_idx * block_size; + BlockwiseQuantBlock* blob_ptr = dst_blob + block_idx; + size_t offset = SafeInt(k) * N + n; + if (nullptr != zero_points_tmp_ptr) { + blob_ptr->quant(src + offset, scale[block_idx], zero_points_tmp_ptr[block_idx], k, K, N); + } else { + blob_ptr->quant(src + offset, scale[block_idx], k, K, N); + } + }, + 0); + + if (bits <= 4 && zero_points != nullptr) { // compact zero points + for (int32_t zp_idx = 0; zp_idx < total_block_count / 2; zp_idx++) { + zero_points[zp_idx] = ((zero_points_tmp[zp_idx * 2]) | (zero_points_tmp[zp_idx * 2 + 1] << 4)); + } + if (total_block_count & 1) { + zero_points[total_block_count / 2] = (zero_points[total_block_count / 2] & 0xf0) | zero_points_tmp[total_block_count - 1]; + } + } +} + +#define QuantizeBlockwise4Bits(block_size) \ + QuantizeBlockwise(dst, src, scale, zero_points, N, K, thread_pool); + +template +void QuantizeBlockwise( + uint8_t* dst, // shape: [ N, block_per_K, block_blob_size ] + const T* src, // shape: [K, N] + T* scale, // shape: [N, block_per_K] + uint8_t* zero_points, // shape: [N, block_per_K] + int32_t block_size, + int32_t bits, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE(bits == 4, "only 4 bits is supported now"); + + if (16 == block_size) { + QuantizeBlockwise4Bits(16); + } else if (32 == block_size) { + QuantizeBlockwise4Bits(32); + } else if (64 == block_size) { + QuantizeBlockwise4Bits(64); + } else if (128 == block_size) { + QuantizeBlockwise4Bits(128); + } else if (256 == block_size) { + QuantizeBlockwise4Bits(256); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef QuantizeBlockwise4Bits + +template +void DequantizeBlockwise( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [N, block_per_K, block_blob_size] + const T* scale, // shape: [N, block_per_K] + const uint8_t* zero_points, // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t block_per_K = (K + block_size - 1) / block_size; + int32_t task_count = N * block_per_K; + + const BlockwiseQuantBlock* src_blob = + reinterpret_cast*>(src); + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + task_count, + [&](ptrdiff_t task_idx) { + int32_t n = static_cast(task_idx / block_per_K); + int32_t k_block_idx = static_cast(task_idx % block_per_K); + int32_t k = k_block_idx * block_size; + const BlockwiseQuantBlock* blob_ptr = src_blob + task_idx; + size_t offset = SafeInt(n) * K + k; + if (nullptr != zero_points) { + if constexpr (bits > 4) { // zero point is stored with a byte + blob_ptr->dequant(dst + offset, scale[task_idx], zero_points[task_idx], k, K); + } else { // zero points is stored with 4bits + uint8_t zp = zero_points[task_idx / 2]; + zp = (task_idx & 1) ? (zp >> 4) : (zp & 0xf); + blob_ptr->dequant(dst + offset, scale[task_idx], zp, k, K); + } + } else { + blob_ptr->dequant(dst + offset, scale[task_idx], k, K); + } + }, + 0); +} + +#define DequantizeBlockwise4Bits(block_size) \ + DequantizeBlockwise(dst, src, scale, zero_points, N, K, thread_pool); + +template +void DequantizeBlockwise( + T* dst, // [N, K] + const uint8_t* src, // [N, block_per_K, block_blob_size] + const T* scale, // [N, block_per_K] + const uint8_t* zero_points, // [N, block_per_K] + int32_t block_size, + int32_t bits, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE(bits == 4, "only 4 bits is supported now"); + + if (16 == block_size) { + DequantizeBlockwise4Bits(16); + } else if (32 == block_size) { + DequantizeBlockwise4Bits(32); + } else if (64 == block_size) { + DequantizeBlockwise4Bits(64); + } else if (128 == block_size) { + DequantizeBlockwise4Bits(128); + } else if (256 == block_size) { + DequantizeBlockwise4Bits(256); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef DequantizeBlockwise4Bits + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h new file mode 100644 index 0000000000..5ddb77e5b5 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "blockwise_quant_block_bnb4.h" + +#include + +#include "core/common/safeint.h" +#include "core/framework/float16.h" +#include "core/platform/threadpool.h" +#include + +namespace onnxruntime { +namespace contrib { + +template +void QuantizeBlockwiseBnb4( + uint8_t* dst, // shape: [(N * K + 1) / 2] + const T* src, // shape: [N, K] + T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t numel = N * K; + int32_t total_block_count = (numel + block_size - 1) / block_size; + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + QuantizeBlockBnb4( + src, + dst, + absmax[block_idx], + static_cast(block_idx), + numel); + }, + 0); +} + +#define QuantizeBlockwiseBn4DataTyped(block_size, quant_type) \ + if (quant_type == FP4) \ + QuantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); \ + else \ + QuantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); + +template +void QuantizeBlockwiseBnb4( + uint8_t* dst, // shape: [(N * K + 1) / 2] + const T* src, // shape: [N, K] + T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + if (block_size == 16) { + QuantizeBlockwiseBn4DataTyped(16, quant_type); + } else if (block_size == 32) { + QuantizeBlockwiseBn4DataTyped(32, quant_type); + } else if (block_size == 64) { + QuantizeBlockwiseBn4DataTyped(64, quant_type); + } else if (block_size == 128) { + QuantizeBlockwiseBn4DataTyped(128, quant_type); + } else if (block_size == 256) { + QuantizeBlockwiseBn4DataTyped(256, quant_type); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef QuantizeBlockwiseBn4DataTyped + +template +void DequantizeBlockwiseBnb4( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [(N * K + 1) / 2)] + const T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t numel = N * K; + int32_t total_block_count = (numel + block_size - 1) / block_size; + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + DequantizeBlockBnb4( + src, + dst, + absmax[block_idx], + static_cast(block_idx), + numel); + }, + 0); +} + +#define DequantizeBlockwiseBn4DataTyped(block_size, quant_type) \ + if (quant_type == FP4) \ + DequantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); \ + else \ + DequantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); + +template +void DequantizeBlockwiseBnb4( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [(N * K + 1) / 2)] + const T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + if (block_size == 16) { + DequantizeBlockwiseBn4DataTyped(16, quant_type); + } else if (block_size == 32) { + DequantizeBlockwiseBn4DataTyped(32, quant_type); + } else if (block_size == 64) { + DequantizeBlockwiseBn4DataTyped(64, quant_type); + } else if (block_size == 128) { + DequantizeBlockwiseBn4DataTyped(128, quant_type); + } else if (block_size == 256) { + DequantizeBlockwiseBn4DataTyped(256, quant_type); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef DequantizeBlockwiseBn4DataTyped + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc new file mode 100644 index 0000000000..2f3ede49c3 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/common.h" +#include "dequantize_blockwise_bnb4.h" +#include "core/mlas/inc/mlas.h" + +namespace onnxruntime { +namespace contrib { + +class MatMulBnb4 final : public OpKernel { + public: + MatMulBnb4(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("quant_type", &quant_type_)); + ORT_ENFORCE( + quant_type_ == FP4 || quant_type_ == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t quant_type_; +}; + +Status MatMulBnb4::Compute(OpKernelContext* ctx) const { + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const Tensor* a = ctx->Input(0); + const Tensor* b_quant = ctx->Input(1); + const Tensor* absmax = ctx->Input(2); + + const float* a_data = a->Data(); + const uint8_t* b_quant_data = b_quant->Data(); + const float* absmax_data = absmax->Data(); + + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + DequantizeBlockwiseBnb4( + tmp_b_data_ptr.get(), + b_quant_data, + absmax_data, + static_cast(block_size_), + static_cast(quant_type_), + static_cast(N_), + static_cast(K_), + thread_pool); + + constexpr bool transa = false; + constexpr bool transb = true; + TensorShape b_shape({N_, K_}); + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(transa); + const size_t ldb = helper.Ldb(transb); + + // TODO: implement with native kernel + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), max_len, thread_pool); + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc new file mode 100644 index 0000000000..57aada94be --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/common.h" +#include "dequantize_blockwise.h" +#include "core/mlas/inc/mlas.h" + +namespace onnxruntime { +namespace contrib { + +class MatMulNBits final : public OpKernel { + public: + MatMulNBits(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; +}; + +Status MatMulNBits::Compute(OpKernelContext* ctx) const { + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const Tensor* a = ctx->Input(0); + const Tensor* b = ctx->Input(1); + const Tensor* scales = ctx->Input(2); + const Tensor* zero_points = ctx->Input(3); + + const auto* a_data = a->Data(); + const uint8_t* b_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + DequantizeBlockwise(tmp_b_data_ptr.get(), + b_data, + scales_data, + zero_points_data, + static_cast(block_size_), + static_cast(nbits_), + static_cast(N_), + static_cast(K_), + thread_pool); + +#if 0 // for debug + auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); +#endif + + TensorShape b_shape({N_, K_}); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) + return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(false); + const size_t ldb = helper.Ldb(true); + + // TODO: implement with native kernel + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + MlasGemmBatch(CblasNoTrans, CblasTrans, + M, N, K, data.data(), max_len, thread_pool); + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index e86a12d9fb..4e103c2556 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -20,20 +20,29 @@ namespace contrib { kCpuExecutionProvider, \ KernelDefBuilder() \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SkipLayerNorm); + SkipLayerNorm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SkipSimplifiedLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SkipLayerNorm); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) -template -SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) +template +SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); } -template -Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { +template +Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { const Tensor* input = p_ctx->Input(0); const Tensor* skip = p_ctx->Input(1); const Tensor* gamma = p_ctx->Input(2); @@ -102,10 +111,16 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { } mean = mean / hidden_size; - mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_); + if (simplified) { + mean_square = sqrt(mean_square / hidden_size + epsilon_); + } else { + mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_); + } for (int64_t h = 0; h < hidden_size; h++) { - if (nullptr == beta_data) { + if (simplified) { + p_output[h] = p_output[h] / mean_square * gamma_data[h]; + } else if (nullptr == beta_data) { p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h]; } else { p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h]; diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h index 7723541cb6..69edf4609e 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { -template +template class SkipLayerNorm final : public OpKernel { public: SkipLayerNorm(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index d846f55f1e..626e4c0b87 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -287,9 +287,9 @@ __global__ void AddBiasTransposeQKV(int M, const T* input, const T* biases, T* o T* k_smem = q_smem + rotary_embedding_dim; const int half_rotary_dim = rotary_embedding_dim / 2; - const int half_idx = (head_idx) / half_rotary_dim; - const int intra_half_idx = (head_idx) % half_rotary_dim; - const int smem_pitch = half_rotary_dim; + const int half_idx = (head_idx) / half_rotary_dim; + const int intra_half_idx = (head_idx) % half_rotary_dim; + const int smem_pitch = half_rotary_dim; if (do_rotary) { *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; @@ -441,7 +441,6 @@ __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, co } } - template __global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* output, int v_head_size) { // Format 3 for cutlass memory efficient attention @@ -651,7 +650,7 @@ void InvokeAddBiasTranspose( if (format != 1 && format != 2 && format != 3) { ORT_THROW("format must be 1, 2 or 3 for rotary attention"); } - if (qk_head_size != 64 && qk_head_size !=128) { + if (qk_head_size != 64 && qk_head_size != 128) { ORT_THROW("qk_head_size must be 64 or 128 for rotary attention"); } if (v_head_size != -1 && qk_head_size != v_head_size) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index b8066567fc..2f90bfde89 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -8,6 +8,7 @@ #include "contrib_ops/cuda/bert/attention.h" #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -39,20 +40,36 @@ REGISTER_KERNEL_TYPED(MLFloat16) template Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) { - disable_fused_self_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); + disable_fused_self_attention_ = + sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); - enable_trt_flash_attention_ = sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); + enable_trt_flash_attention_ = + sizeof(T) == 2 && + !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); - enable_fused_causal_attention_ = sizeof(T) == 2 && - ParseEnvironmentVariableWithDefault(attention::kEnableFusedCausalAttention, false); + enable_fused_causal_attention_ = + sizeof(T) == 2 && + ParseEnvironmentVariableWithDefault(attention::kEnableFusedCausalAttention, false); -#if USE_FLASH_ATTENTION - disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); +#if USE_MEMORY_EFFICIENT_ATTENTION + disable_memory_efficient_attention_ = + ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); #else disable_memory_efficient_attention_ = true; #endif + +#if USE_FLASH_ATTENTION + disable_flash_attention_ = + sizeof(T) != 2 || + onnxruntime::ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); + min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( + attention::kMinSeqLenForFlashAttentionPackedQKV, + attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); +#else + disable_flash_attention_ = true; + min_seq_len_for_flash_attention_packed_qkv_ = 0; +#endif } template @@ -100,71 +117,114 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { MHARunner* fused_runner = nullptr; // Check whether we can use fused kernel - int sm = device_prop.major * 10 + device_prop.minor; - bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - bool is_mask_1d_key_seq_len_start = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; - - if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT - // GPT fused kernels requires left side padding. mask can be: - // none (no padding), 1D sequence lengths or 2d mask. - // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token - // where past state is empty. - bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; - bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == relative_position_bias && - parameters.past_sequence_length == 0 && - parameters.hidden_size == parameters.v_hidden_size && - FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, true); - if (use_causal_fused_runner) { - // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. - if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); - } + const int sm = device_prop.major * 10 + device_prop.minor; + const bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. - fused_runner = fused_fp16_runner_.get(); - } - } else { // BERT - bool use_fused_runner = !disable_fused_self_attention_ && - (nullptr == mask_index || is_mask_1d_seq_len) && - nullptr == past && - nullptr == present && - nullptr == relative_position_bias && - parameters.hidden_size == parameters.v_hidden_size && - FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, false); - - if (use_fused_runner) { - // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. - if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); +#if USE_FLASH_ATTENTION + bool use_flash_attention = !disable_flash_attention_ && + (nullptr == relative_position_bias) && + nullptr == past && + nullptr == present && + parameters.hidden_size == parameters.v_hidden_size && + nullptr == mask_index && + onnxruntime::flash::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.num_heads); + // When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512. + if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + use_flash_attention = false; + } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); +#else + constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr +#endif + + if (!use_flash_attention) { + if (is_unidirectional_) { // GPT + if (enable_fused_causal_attention_) { + // GPT fused kernels requires left side padding. mask can be: + // none (no padding), 1D sequence lengths or 2d mask. + // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token + // where past state is empty. + bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; + bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && + nullptr == relative_position_bias && + parameters.past_sequence_length == 0 && + parameters.hidden_size == parameters.v_hidden_size && + FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, true); + if (use_causal_fused_runner) { + // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. + if (nullptr == fused_fp16_runner_.get()) { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + } + + // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. + fused_runner = fused_fp16_runner_.get(); + } } + } else { // BERT + bool use_fused_runner = !disable_fused_self_attention_ && + (nullptr == mask_index || is_mask_1d_seq_len) && + nullptr == past && + nullptr == present && + nullptr == relative_position_bias && + parameters.hidden_size == parameters.v_hidden_size && + FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, false); + + if (use_fused_runner) { + // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. + if (nullptr == fused_fp16_runner_.get()) { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + } - // In case some kernel not loaded due to shared memory limit, we need to double check here. - const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length); - if (fused_fp16_runner_->isValid(S)) { - fused_runner = fused_fp16_runner_.get(); + // In case some kernel not loaded due to shared memory limit, we need to double check here. + const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length); + if (fused_fp16_runner_->isValid(S)) { + fused_runner = fused_fp16_runner_.get(); + } } } } -#if USE_FLASH_ATTENTION - bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; - bool use_memory_efficient_attention = fused_runner == nullptr && - !disable_memory_efficient_attention_ && - (nullptr == mask_index || is_mask_1d_key_seq_len_start) && - nullptr == past && - nullptr == present && - (nullptr == relative_position_bias || is_good_for_rpb) && - (sizeof(T) == 2 || // sequence length threshold is 0 in FP16 - parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) && - has_memory_efficient_attention(sm, sizeof(T) == 2); +#if USE_MEMORY_EFFICIENT_ATTENTION + bool use_memory_efficient_attention = + !use_flash_attention && + fused_runner == nullptr && + !disable_memory_efficient_attention_ && + nullptr == past && + nullptr == present && + (parameters.head_size & 7) == 0 && + (parameters.v_head_size & 7) == 0 && + (nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && + (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + has_memory_efficient_attention(sm, sizeof(T) == 2); + + if (use_memory_efficient_attention) { + bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; + use_memory_efficient_attention = (nullptr == relative_position_bias || is_good_for_rpb); + } #else constexpr bool use_memory_efficient_attention = false; - ORT_UNUSED_PARAMETER(is_mask_1d_key_seq_len_start); #endif cublasHandle_t cublas = GetCublasHandle(context); @@ -199,6 +259,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.kv_sequence_length, parameters.total_sequence_length, fused_runner, + use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -206,27 +267,34 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); - data.bias = nullptr == bias ? nullptr : reinterpret_cast(bias->Data()); - data.query = nullptr; - data.key = nullptr; - data.value = nullptr; - data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); - data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); - data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data()); - data.past_key = nullptr; - data.past_value = nullptr; - data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); + if (nullptr != bias) { + data.bias = reinterpret_cast(bias->Data()); + } + if (nullptr != mask_index) { + data.mask_index = mask_index->Data(); + data.mask_index_dims = mask_index->Shape().GetDims(); + } + if (nullptr != past) { + data.past = reinterpret_cast(past->Data()); + } + if (nullptr != relative_position_bias) { + data.relative_position_bias = reinterpret_cast(relative_position_bias->Data()); + } data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); - data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData()); - data.present_key = nullptr; - data.present_value = nullptr; + if (nullptr != present) { + data.present = reinterpret_cast(present->MutableData()); + } data.fused_runner = reinterpret_cast(fused_runner); - data.fused_cross_attention_kernel = nullptr; + data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; - data.cumulated_sequence_length_q_cache = nullptr; - data.cumulated_sequence_length_kv_cache = nullptr; + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } return QkvToContext(device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h index ba7c56c04f..455e55ba05 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention.h @@ -21,10 +21,12 @@ class Attention final : public CudaKernel, public AttentionBase { Status ComputeInternal(OpKernelContext* context) const override; protected: + bool disable_flash_attention_; bool disable_fused_self_attention_; bool enable_trt_flash_attention_; bool enable_fused_causal_attention_; bool disable_memory_efficient_attention_; + int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu b/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu deleted file mode 100644 index 5d9cfcc697..0000000000 --- a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/cuda/cuda_common.h" -#include "contrib_ops/cuda/bert/attention_impl.h" - -using namespace onnxruntime::cuda; - -namespace onnxruntime { -namespace contrib { -namespace cuda { - -template -__global__ void ConcatTensorToTensor(const int tensor_add_sequence_length, - const T* tensor_in, - const T* tensor_add, - T* tensor_out) { - const int h = threadIdx.x; - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; - const int chunk_id = blockIdx.z; - - const int all_sequence_length = gridDim.x; - const int batch_size = gridDim.y; - const int num_heads = blockDim.y; - const int H = blockDim.x; - - // K: number of identical tensors - // tensor_in: K x BxNxPxH - // tensor_add: K x BxNxLxH - // tensor_out: K x BxNxTxH, where T = P + L - const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length; - - const int present_SH = all_sequence_length * H; - const int present_NSH = num_heads * present_SH; - int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size); - if (s < tensor_in_sequence_length) { - const int past_SH = tensor_in_sequence_length * H; - const int past_NSH = num_heads * past_SH; - const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size); - tensor_out[out_offset] = tensor_in[in_offset]; - } else if (s < all_sequence_length) { - const int SH = tensor_add_sequence_length * H; - const int NSH = num_heads * SH; - const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size); - tensor_out[out_offset] = tensor_add[in_offset]; - } -} - -template -__global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length, - const int H, - const T* tensor_in, - const T* tensor_add, - T* tensor_out) { - // Use when (H*)*num_heads > 1024 - int h = threadIdx.x; - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; - const int chunk_id = blockIdx.z; - - const int all_sequence_length = gridDim.x; - const int batch_size = gridDim.y; - const int num_heads = blockDim.y; - const int stride = blockDim.x; - - // K: number of identical tensor - // tensor_in: K x BxNxPxH - // tensor_add: K x BxNxLxH - // tensor_out: K x BxNxTxH - const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length; - - const int present_SH = all_sequence_length * H; - const int present_NSH = num_heads * present_SH; - while (h < H) { - int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size); - if (s < tensor_in_sequence_length) { - const int past_SH = tensor_in_sequence_length * H; - const int past_NSH = num_heads * past_SH; - const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size); - tensor_out[out_offset] = tensor_in[in_offset]; - } else if (s < all_sequence_length) { - const int SH = tensor_add_sequence_length * H; - const int NSH = num_heads * SH; - const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size); - tensor_out[out_offset] = tensor_add[in_offset]; - } - - h += stride; - } -} - -Status LaunchConcatTensorToTensor(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const float* tensor_in, - const float* tensor_add, - float* tensor_out) { - const dim3 grid(all_sequence_length, batch_size, matrix_num); - if (0 == (head_size & 1)) { - const int H = head_size / 2; - if (H * num_heads <= max_threads_per_block) { - const dim3 block(H, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, - reinterpret_cast(tensor_in), - reinterpret_cast(tensor_add), - reinterpret_cast(tensor_out)); - } else { - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, - H, - reinterpret_cast(tensor_in), - reinterpret_cast(tensor_add), - reinterpret_cast(tensor_out)); - } - } else { - if (head_size * num_heads <= max_threads_per_block) { - const dim3 block(head_size, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, tensor_in, tensor_add, tensor_out); - } else { - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, - head_size, - tensor_in, - tensor_add, - tensor_out); - } - } - return CUDA_CALL(cudaGetLastError()); -} - -Status LaunchConcatTensorToTensor(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const int matrix_num, - const half* tensor_in, - const half* tensor_add, - half* tensor_out) { - const dim3 grid(all_sequence_length, batch_size, matrix_num); - if (0 == (head_size % 4)) { - const int H = head_size / 4; - if (H * num_heads <= max_threads_per_block) { - const dim3 block(H, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, - reinterpret_cast(tensor_in), - reinterpret_cast(tensor_add), - reinterpret_cast(tensor_out)); - } else { - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, - H, - reinterpret_cast(tensor_in), - reinterpret_cast(tensor_add), - reinterpret_cast(tensor_out)); - } - } else if (0 == (head_size & 1)) { - const int H = head_size / 2; - if (H * num_heads <= max_threads_per_block) { - const dim3 block(H, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, - reinterpret_cast(tensor_in), - reinterpret_cast(tensor_add), - reinterpret_cast(tensor_out)); - } else { - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, - H, - reinterpret_cast(tensor_in), - reinterpret_cast(tensor_add), - reinterpret_cast(tensor_out)); - } - } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel. - if (head_size * num_heads <= max_threads_per_block) { - const dim3 block(head_size, num_heads, 1); - ConcatTensorToTensor<<>>(sequence_length, tensor_in, tensor_add, tensor_out); - } else { - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - ConcatTensorToTensorLarge<<>>(sequence_length, - head_size, - tensor_in, - tensor_add, - tensor_out); - } - } - return CUDA_CALL(cudaGetLastError()); -} - -Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* past, - const float* k_v, - float* present) { - return LaunchConcatTensorToTensor( - stream, - all_sequence_length, - sequence_length, - batch_size, - head_size, - num_heads, - max_threads_per_block, - 2, - past, - k_v, - present); -} - -Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* past, - const half* k_v, - half* present) { - return LaunchConcatTensorToTensor( - stream, - all_sequence_length, - sequence_length, - batch_size, - head_size, - num_heads, - max_threads_per_block, - 2, - past, - k_v, - present); -} - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 4d478ef158..83c426e7e6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -26,22 +26,19 @@ limitations under the License. // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/attention_softmax.h" #include "contrib_ops/cuda/bert/transformer_common.h" -#include "contrib_ops/cuda/bert/add_bias_transpose.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/bert/attention_impl.h" using namespace onnxruntime::cuda; using namespace onnxruntime::contrib::attention_softmax_cuda; @@ -64,7 +61,8 @@ size_t AlignSize(size_t bytes) { void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) { if (this->sequence_length != sequence_length) { ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); - LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, this->max_batch_size, sequence_length, stream); + LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, + this->max_batch_size, sequence_length, stream); this->sequence_length = sequence_length; } } @@ -114,6 +112,7 @@ size_t GetAttentionWorkspaceSize( size_t kv_sequence_length, size_t total_sequence_length, void* fused_runner, + bool use_flash_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention) { // Note that q, k and v might need alignment for fused attention kernels. @@ -121,6 +120,14 @@ size_t GetAttentionWorkspaceSize( ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size); #if USE_FLASH_ATTENTION + if (use_flash_attention) { + return qkv_bytes + onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads); + } +#else + ORT_UNUSED_PARAMETER(use_flash_attention); +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION if (use_memory_efficient_attention) { size_t fmha_buffer_bytes = 0; if (MemoryEfficientAttentionParams::need_workspace(v_head_size, element_size == sizeof(float))) { @@ -146,765 +153,290 @@ size_t GetAttentionWorkspaceSize( } template -__global__ void AddBiasTransAppendKvToPresentSmall( - const T* qkv, const T* biases, T* present, - const int head_size, const int past_sequence_length, const int max_sequence_length) { - // Input: BxSxMxNxH (Format 1) - // Output: (2, B, N, [P..P+S) of MaxS, H), - // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; - const int N = blockDim.y; - const int S = gridDim.x; - const int B = gridDim.y; - - constexpr int M = 3; // Matrix count in qkv - const int m = blockIdx.z + 1; // k = 1, v = 2 - - const int NH = N * head_size; - const int NHS = NH * S; - - qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); - if (biases) { - biases += (m * NH + n * head_size); - } +Status FusedTrtCrossAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + assert(data.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); - const int MsH = max_sequence_length * head_size; - const int NMsH = N * MsH; - const int BNMsH = B * NMsH; - present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + // We only enable fused cross attention when there is no key padding mask. + // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. + assert(data.mask_index == nullptr); - for (int h = threadIdx.x; h < head_size; h += blockDim.x) { - T bias = (biases ? biases[h] : (T)0.0f); - present[h] = qkv[h] + bias; - } -} + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, + sequence_length, stream, + data.scratch); -template -__global__ void AddBiasTransAppendKvToPresent( - const T* qkv, const T* biases, T* present, - const int head_size, const int past_sequence_length, const int max_sequence_length) { - // Input: BxSxMxNxH (Format 1) - // Output: (2, B, N, [P..P+S) of MaxS, H), - // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size - const int n = blockIdx.x; - const int s = blockIdx.y; - const int b = (blockIdx.z >> 1); - const int N = gridDim.x; - const int S = gridDim.y; - const int B = (gridDim.z >> 1); - - constexpr int M = 3; // Matrix count in qkv - const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2 - - const int NH = N * head_size; - const int NHS = NH * S; - - qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); - if (biases) { - biases += (m * NH + n * head_size); - } + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - const int MsH = max_sequence_length * head_size; - const int NMsH = N * MsH; - const int BNMsH = B * NMsH; - present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); + kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, + data.mask_index, batch_size, parameters.kv_sequence_length, stream, + kv_sequence_offset); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); - for (int h = threadIdx.x; h < head_size; h += blockDim.x) { - T bias = (biases ? biases[h] : (T)0.0f); - present[h] = qkv[h] + bias; - } -} + DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); -// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3. -// bias is of shape (3, NxH) or nullptr -// append to present of (2, B, N, (P..T) of M, H), -template -Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int past_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const T* biases, - const T* qkv_buffer, - T* present) { - assert(head_size <= (1 << 30)); - - int64_t nh = (int64_t)head_size * num_heads; - if (nh <= max_threads_per_block) { - const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - - AddBiasTransAppendKvToPresentSmall<<>>( - qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); - } else { - const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v - const dim3 block(std::min(head_size, max_threads_per_block), 1, 1); - AddBiasTransAppendKvToPresent<<>>( - qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = + reinterpret_cast(data.fused_cross_attention_kernel); + + // When there is no bias, we can directly use q and packed kv from inputs. + void const* query = data.q; + void const* packed_kv = data.k; + if (data.value == nullptr && data.bias == nullptr) { + query = data.query; + packed_kv = data.key; } - return CUDA_CALL(cudaGetLastError()); + run_fused_cross_attention( + query, // Q + packed_kv, // packed KV + q_sequence_offset, // cumulated sequence length of Q + kv_sequence_offset, // cumulated sequence length of KV + data.output, // output + cross_attention_kernel, // kernels + batch_size, // batch size + parameters.num_heads, // number of heads + parameters.head_size, // head size of Q/K/V + sequence_length, // sequence length of Q + parameters.kv_sequence_length, // sequence length of KV + stream); + + DUMP_TENSOR("trt cross output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + return Status::OK(); } -template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int total_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* bias, - const float* qkv_buffer, - float* present); - -template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int total_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* bias, - const half* qkv_buffer, - half* present); +template <> +Status FusedTrtCrossAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "Trt fused cross attention does not support float tensor"); +} template -Status PrepareQkv(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { +Status FusedTrtSelfAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - const bool past_present_share_buffer = parameters.past_present_share_buffer; - void* fused_runner = data.fused_runner; - bool use_memory_efficient_attention = data.use_memory_efficient_attention; + const bool causal = parameters.is_unidirectional; - T* qkv = data.workspace; + int* sequence_offset = reinterpret_cast(data.scratch); - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - // Default format for memory efficient attention. - // When there is past state, the format shall be BxNxSxH, so we disable memory efficient attention when there is past. DUMP_TENSOR_INIT(); - if (nullptr != data.gemm_buffer) { - if (data.bias == nullptr) { - assert(nullptr == fused_runner); - // For quantized attention, bias has been added so only need transpose here. - // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH - assert(qk_head_size == v_head_size); - int matrix_to_trans = (past_present_share_buffer ? 1 : 3); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.gemm_buffer, qkv, 3)); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } else { - // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) - // For memory efficient attention, transpose to 3xBxSxNxH (format 3) - // For unfused kernel, transpose to 3xBxNxSxH (format 1) - // For fused causal kernel, use format 1 since we need have K and V to update present state, - // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. - const int format = (use_fused_kernel ? 2 : (use_memory_efficient_attention ? 3 : 1)); - qkv_format = use_fused_kernel - ? AttentionQkvFormat::QKV_BSN3H - : (use_memory_efficient_attention - ? AttentionQkvFormat::Q_K_V_BSNH - : (use_fused_causal ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH : AttentionQkvFormat::Q_K_V_BNSH)); - - // For fused causal, we will update gemm_buffer with bias directly. - T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; - - int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3); - // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v - // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) - LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, - 3, parameters.do_rotary, parameters.past_sequence_length); - } + if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); + LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + } else { + sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, sequence_length, stream, + sequence_offset); } - // attention with past/present state - else if (data.past_key != nullptr || data.present_key != nullptr) { - // Below logic does not support memory efficient attention with past (like pass_past_in_kv) but without bias - if (data.bias == nullptr) { - // cross attention with past state - if (data.past_key != nullptr && data.present_key == nullptr) { - assert(data.past_value != nullptr); - assert(data.query != nullptr); - assert(data.key == nullptr); - assert(data.value == nullptr); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - } - // cross attention with present state or self attention with present state - else if (data.past_key == nullptr && data.present_key != nullptr) { - assert(data.past_value == nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - - // TODO: supporting packed kv for cross attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, data.present_key)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, data.present_value)); - } - // self attention with past and present state - else { - assert(data.past_key != nullptr); - assert(data.past_value != nullptr); - assert(data.present_key != nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, k)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, v)); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } -#if USE_FLASH_ATTENTION - // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value - else if (use_memory_efficient_attention && data.past_key != nullptr && data.past_value != nullptr && parameters.pass_past_in_kv) { - // Transpose past_key and past_value to use memory efficient attention - - // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_key, data.temp_k_workspace)); - // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_value, data.temp_v_workspace)); - - // query => q, temp_k_workspace => k, temp_v_workspace => v - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - - data.past_key = nullptr; - data.past_value = nullptr; - } - // When there is no past_key/past_value and there is present_key/present_value (e.g. get initial kv to use as past_kv in the next iteration) - else if (use_memory_efficient_attention && data.present_key != nullptr && data.present_value != nullptr) { - // Use memory efficient attention kernel - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); - - // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.temp_k_workspace, data.present_key)); - - // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.temp_v_workspace, data.present_value)); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size * kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } -#endif - else { - // Use unfused kernel for Q, use unfused kernel for K and V if needed - constexpr int format = 0; - // Query (BxSxNxH) => Q (BxNxSxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); - - if (!parameters.pass_past_in_kv) { - T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k; - T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v; - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, data.bias + num_heads * qk_head_size, k_dest, - true, -1); - - // Value (BxLxNxH_v) => V (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, - true, -1); - - DUMP_TENSOR_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size * num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size * num_heads, kv_sequence_length, v_head_size); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } - } else if (data.key == nullptr) { // gemm_buffer == nullptr and packed qkv - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); + DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); - - if (use_memory_efficient_attention) { - // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. - constexpr int format = 4; - T* qkv_add_bias = nullptr; - LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, qkv, - true, v_head_size, qkv_add_bias, 3); - DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (!use_fused_kernel) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); - } - - qkv_format = AttentionQkvFormat::QKV_BSN3H; - } - } else if (data.value == nullptr) { // gemm_buffer == nullptr and packed kv - // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. - // CheckInputs verified this constraint. - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); + FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(data.fused_runner); - DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); - - if (use_memory_efficient_attention) { - // unpack kv to BSNH. Note that there is no bias so we need not output query to q. - constexpr int format = 4; - T* qkv_add_bias = nullptr; - const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); - LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, kv_bias, k, - true, v_head_size, qkv_add_bias, 2); - DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (data.fused_cross_attention_kernel == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); - } - - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } - } else { // gemm_buffer == nullptr and not packed - assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); - - DUMP_TENSOR_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("key", data.key, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size); - -#if DUMP_TENSOR_LEVEL > 1 - if (data.bias != nullptr) { - DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); - DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); - DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); - } -#endif + const int S = causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length); - if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { - DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, num_heads, sequence_length, kv_sequence_length); - } + // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed. + const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size); - if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { - DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); + fused_fp16_runner->setup(S, B); + + if (!causal) { + assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H); + + // When there is no bias, we can directly use packed qkv from inputs. + void const* packed_qkv = data.q; + if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { + packed_qkv = data.query; } - if (data.fused_cross_attention_kernel != nullptr) { - assert(qk_head_size == v_head_size); + fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); + DUMP_TENSOR("fused output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + } else { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); + DUMP_TENSOR("fused causal output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + } + return Status::OK(); +} - // For fused cross attention, besides adding bias, K and V needed to be packed: - // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length); +// Template Specialization for float type +template <> +Status FusedTrtSelfAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "Trt fused attention does not support float tensor"); +} - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } #if USE_FLASH_ATTENTION - else if (use_memory_efficient_attention) { - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } +template +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(nullptr == data.mask_index); + assert(nullptr == data.relative_position_bias); + assert(parameters.head_size == parameters.v_head_size); + + void* query = reinterpret_cast(data.q); + void* key = reinterpret_cast(data.k); + void* value = reinterpret_cast(data.v); + // For packed KV, we can use query input directly. + if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { + query = reinterpret_cast(const_cast(data.query)); + } + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, + parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, + parameters.batch_size, parameters.total_sequence_length, + parameters.num_heads, parameters.v_head_size); + + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), + parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, + parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), + true)); + + DUMP_TENSOR("flash attention output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); + + return Status::OK(); +} + +template <> +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "flash attention does not support float tensor"); +} #endif - else if (use_fused_kernel) { - assert(qk_head_size == v_head_size); - - // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H) - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); - DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); - - qkv_format = AttentionQkvFormat::QKV_BSN3H; - } else { // unfused kernel - ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal"); - - // Query (BxSxNxH) => Q (BxNxSxH) - constexpr int format = 0; - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, - true, -1); - - // Value (BxLxNxH_v) => K (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, - true, -1); - - DUMP_TENSOR_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k, batch_size * num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v, batch_size * num_heads, kv_sequence_length, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } + +#if USE_MEMORY_EFFICIENT_ATTENTION +template +Status EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + // We only enable fused cross attention when there is no key padding mask. + // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + + const void* query = data.q; + const void* key = data.k; + const void* value = data.v; + // For packed KV, we can use query input directly. + if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { + assert(data.bias == nullptr); + query = data.query; } - CUDA_RETURN_IF_ERROR(cudaGetLastError()); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, + parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, + parameters.batch_size, parameters.total_sequence_length, + parameters.num_heads, parameters.v_head_size); + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = parameters.batch_size; + p.num_heads = parameters.num_heads; + p.sequence_length = parameters.sequence_length; + p.kv_sequence_length = parameters.total_sequence_length; + p.max_sequence_length = parameters.total_sequence_length; + p.qk_head_size = parameters.head_size; + p.v_head_size = parameters.v_head_size; + p.causal = parameters.is_unidirectional; + p.scale = scale; + p.seqlen_k_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast(data.mask_index)); + p.seqstart_q_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast( + data.mask_index + parameters.batch_size)); + p.seqstart_k_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast( + data.mask_index + 2 * parameters.batch_size + 1)); + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; + p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; + p.output = data.output; + p.is_kv_bsnh = true; + p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) + ? data.scratch + : nullptr; + p.stream = stream; + p.has_custom_right_padding = false; + run_memory_efficient_attention(p); + DUMP_TENSOR("efficient attention output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); + return Status::OK(); } +#endif template -Status QkvToContext( +Status UnfusedAttention( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, Stream* ort_stream, contrib::AttentionParameters& parameters, - AttentionData& data) { + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH); + auto stream = static_cast(ort_stream->GetHandle()); - constexpr size_t element_size = sizeof(T); - const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; const int total_sequence_length = parameters.total_sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; - const bool past_present_share_buffer = parameters.past_present_share_buffer; - const float mask_filter_value = parameters.mask_filter_value; - void* fused_runner = data.fused_runner; - - // At most one fused kernel is enabled. - assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1); - const int batches = batch_size * num_heads; - T* qkv = nullptr; - T* q = nullptr; - T* k = nullptr; - T* v = nullptr; - T* scratch1 = data.workspace; - if (data.has_qkv_workspace) { - const int size_per_batch_q = sequence_length * qk_head_size; - const int size_per_batch_k = kv_sequence_length * qk_head_size; - const int size_per_batch_v = kv_sequence_length * v_head_size; - const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); - const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); - const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); - qkv = data.workspace; - q = qkv; - k = q + elements_q; - v = k + elements_k; - scratch1 = v + elements_v; - } - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - - int present_size_per_batch_k = 0; - int present_size_per_batch_v = 0; - if (!past_present_share_buffer) { - // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. - // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) - // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) - // When there is past state, the head size for Q/K/V shall be same: H == H_v. - present_size_per_batch_k = total_sequence_length * qk_head_size; - present_size_per_batch_v = total_sequence_length * v_head_size; - - if (nullptr != data.present) { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH || qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); - ORT_RETURN_IF_ERROR( - LaunchConcatPastToPresent(stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, data.past, k, data.present)); - - // Update pointers to present_k and present_v. - k = data.present; - v = data.present + batches * present_size_per_batch_k; - } - - if (nullptr != data.past_key || nullptr != data.present_key) { - if (nullptr != data.past_key && nullptr == data.present_key) { - k = const_cast(data.past_key); - v = const_cast(data.past_value); - } else if (nullptr == data.past_key && nullptr != data.present_key) { - if (qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { - k = data.present_key; - v = data.present_value; - } else { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - k = data.temp_k_workspace; - v = data.temp_v_workspace; - } - } else if (parameters.pass_past_in_kv) { - // past_key and past_value are used directly as key and value in attention computations - k = const_cast(data.past_key); - v = const_cast(data.past_value); - - // This path has a memory copy from past_key and past_value to present_key and present_value - // Avoid this path since the memory copy is unnecessary because past_key == present_key and - // past_value == present_value - int64_t k_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * qk_head_size; - int64_t v_size = (int64_t)batch_size * num_heads * parameters.total_sequence_length * v_head_size; - cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } else { - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, 1, data.past_key, k, data.present_key)); - ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, 1, data.past_value, v, data.present_value)); - // Update pointers to present_k and present_v. - k = data.present_key; - v = data.present_value; - } - } - } else { - assert(qk_head_size == v_head_size); - assert(data.fused_cross_attention_kernel == nullptr); - assert(!use_fused_kernel); - assert(data.gemm_buffer != nullptr); - assert(!data.use_memory_efficient_attention); - assert(data.has_qkv_workspace); - - if (nullptr != data.past_key || nullptr != data.present_key) { - // TODO: support this case. - ORT_THROW("buffer sharing for no bias case between past and present is not supported yet."); - } - - if (data.present != data.past) { - // For easy testing. Production should better avoid this path. - int64_t kv_size = 2LL * (int64_t)batch_size * num_heads * parameters.max_sequence_length * qk_head_size; - cudaMemcpyAsync(data.present, data.past, kv_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } - - // append last k v to present - ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( - stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, max_threads_per_block, - use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer - data.gemm_buffer, data.present)); - - present_size_per_batch_k = parameters.max_sequence_length * qk_head_size; - present_size_per_batch_v = present_size_per_batch_k; - k = data.present; - v = data.present + batches * present_size_per_batch_k; - } - - // Q, K and V are ready now - DUMP_TENSOR_INIT(); - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); - - // We only enable fused cross attention when there is no key padding mask. - // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(data.mask_index == nullptr); - - int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - scratch1); - - DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - - int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); - kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, - data.mask_index, batch_size, kv_sequence_length, stream, - kv_sequence_offset); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); - - DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); - - FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = - reinterpret_cast(data.fused_cross_attention_kernel); - - // When there is no bias, we can directly use q and packed kv from inputs. - void const* query = q; - void const* packed_kv = k; - if (data.value == nullptr && data.bias == nullptr) { - query = data.query; - packed_kv = data.key; - } - - run_fused_cross_attention( - query, // Q - packed_kv, // packed KV - q_sequence_offset, // cumulated sequence length of Q - kv_sequence_offset, // cumulated sequence length of KV - data.output, // output - cross_attention_kernel, // kernels - batch_size, // batch size - num_heads, // number of heads - qk_head_size, // head size of Q/K/V - sequence_length, // sequence length of Q - kv_sequence_length, // sequence length of KV - stream); - - DUMP_TENSOR("trt cross output", data.output, batch_size * sequence_length, num_heads, v_head_size); - return Status::OK(); - } - - // Run TRT fused attention. - if (use_fused_kernel || use_fused_causal) { - int* sequence_offset = reinterpret_cast(scratch1); - if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { - DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); - LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); - } else { - sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - sequence_offset); - } - DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); - - FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(fused_runner); - - const int S = use_fused_causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length); - - // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed. - const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size); - - fused_fp16_runner->setup(S, B); - - if (use_fused_kernel) { - assert(qkv_format == AttentionQkvFormat::QKV_BSN3H); - - // When there is no bias, we can directly use packed qkv from inputs. - void const* packed_qkv = qkv; - if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { - packed_qkv = data.query; - } - - fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); - DUMP_TENSOR("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size); - } else { - assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); - fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_TENSOR("fused causal output", data.output, batch_size * sequence_length, num_heads, v_head_size); - } - return Status::OK(); - } - - // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. - const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) - : parameters.scale; - -#if USE_FLASH_ATTENTION - if (data.use_memory_efficient_attention) { - // We only enable fused cross attention when there is no key padding mask. - // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - - const void* query = q; - const void* key = k; - const void* value = v; - // For packed KV, we can use query input directly. - if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { - assert(data.bias == nullptr); - query = data.query; - } - - DUMP_TENSOR_D("attention q(BSNH)", q, batch_size * sequence_length, num_heads * qk_head_size); - DUMP_TENSOR_D("attention k(BSNH)", k, batch_size * sequence_length, num_heads * qk_head_size); - DUMP_TENSOR_D("attention v(BSNH)", v, batch_size * sequence_length, num_heads * v_head_size); - - MemoryEfficientAttentionParams p; - p.sm = device_prop.major * 10 + device_prop.minor; - p.is_half = sizeof(T) == 2; - p.batch_size = parameters.batch_size; - p.num_heads = parameters.num_heads; - p.sequence_length = parameters.sequence_length; - p.kv_sequence_length = parameters.total_sequence_length; - p.qk_head_size = parameters.head_size; - p.v_head_size = parameters.v_head_size; - p.causal = parameters.is_unidirectional; - p.scale = scale; - p.seqlen_k_ptr = nullptr == data.mask_index ? nullptr : const_cast(reinterpret_cast(data.mask_index)); - p.seqstart_q_ptr = nullptr == data.mask_index ? nullptr : const_cast(reinterpret_cast(data.mask_index + batch_size)); - p.seqstart_k_ptr = nullptr == data.mask_index ? nullptr : const_cast(reinterpret_cast(data.mask_index + 2 * batch_size + 1)); - p.query = query; - p.key = key; - p.value = value; - p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; - p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; - p.output = data.output; - p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? scratch1 : nullptr; - p.stream = stream; - run_memory_efficient_attention(p); - DUMP_TENSOR("attention cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size); - return Status::OK(); - } -#endif - - // The following are unfused attention. - assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH); const int* mask_index = data.mask_index; gsl::span& mask_index_dims = data.mask_index_dims; // Raw attention mask could be 2D (BxT) or 3D (BxSxT) or 4D(Bx1xMxM), where M is the max sequence length. bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2); - // Compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxT + // Compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch: BxNxSxT // Q: BxNxSxH, K (present_k): BxNxTxH, Q*K': BxNxSxT float one = 1.0f; float zero = 0.f; @@ -913,22 +445,31 @@ Status QkvToContext( cublasSetStream(cublas, stream); - DUMP_TENSOR_D("q[BNSH]", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k[BNSH]", k, batch_size, num_heads, total_sequence_length, qk_head_size); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q[BNSH]", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k[BNSH]", data.k, batch_size, num_heads, total_sequence_length, qk_head_size); + + const int present_sequence_length = parameters.past_present_share_buffer + ? parameters.max_sequence_length + : total_sequence_length; + const int present_size_per_batch_k = present_sequence_length * qk_head_size; + const int present_size_per_batch_v = present_sequence_length * v_head_size; + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, total_sequence_length, sequence_length, qk_head_size, - &alpha, k, qk_head_size, present_size_per_batch_k, - q, qk_head_size, sequence_length * qk_head_size, - &zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); + &alpha, data.k, qk_head_size, present_size_per_batch_k, + data.q, qk_head_size, sequence_length * qk_head_size, + &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); - DUMP_TENSOR_D("Q", q, batch_size * num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("K", k, batch_size * num_heads, qk_head_size, sequence_length); - DUMP_TENSOR_D("QK", scratch1, batch_size * num_heads, sequence_length, total_sequence_length); + DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); + DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length); + constexpr size_t element_size = sizeof(T); const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, total_sequence_length); - T* scratch2 = scratch1 + (bytes / element_size); + T* scratch2 = data.scratch + (bytes / element_size); // Apply softmax and store result R to scratch2: BxNxSxT if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask @@ -938,13 +479,15 @@ Status QkvToContext( const TransformerOptions* options = TransformerOptions::GetInstance(); bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); - T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. + // replace Q*K' in place with masked score for persistent softmax. + T* persistent_softmax_workspace = data.scratch; ORT_RETURN_IF_ERROR( - ComputeSoftmaxWithRawMask(ort_stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias, - scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension, - parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, - mask_filter_value)); + ComputeSoftmaxWithRawMask( + ort_stream, total_sequence_length, sequence_length, batch_size, num_heads, + mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias, + data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension, + parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, + parameters.mask_filter_value)); } else if (nullptr != mask_index) { // 1d mask index assert(mask_index_dims.size() == 1); // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. @@ -952,274 +495,123 @@ Status QkvToContext( ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D( stream, total_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, data.relative_position_bias, parameters.broadcast_res_pos_bias, - scratch1, scratch2, parameters.is_unidirectional)); + data.scratch, scratch2, parameters.is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR( - ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, - parameters.broadcast_res_pos_bias, scratch1, scratch2, parameters.is_unidirectional)); + ComputeSoftmax( + stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, + parameters.broadcast_res_pos_bias, data.scratch, scratch2, parameters.is_unidirectional)); } - DUMP_TENSOR_D("Softmax", scratch2, batch_size * num_heads, sequence_length, total_sequence_length); - DUMP_TENSOR_D("V", v, batch_size * num_heads, sequence_length, v_head_size); + DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length); + DUMP_TENSOR_D("V", data.v, batch_size, num_heads, sequence_length, v_head_size); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v - T* temp_output = qkv; + T* temp_output = data.q; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, v_head_size, sequence_length, total_sequence_length, - &one, v, v_head_size, present_size_per_batch_v, + &one, data.v, v_head_size, present_size_per_batch_v, scratch2, total_sequence_length, sequence_length * total_sequence_length, &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, temp_output, data.output); - DUMP_TENSOR("unfused output", data.output, batch_size * sequence_length, num_heads, v_head_size); + device_prop.maxThreadsPerBlock, false, temp_output, data.output); + DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size); return result; } template -Status DecoderQkvToContext( +Status QkvToContext( const cudaDeviceProp& device_prop, - Stream* ort_stream, cublasHandle_t& cublas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const T* gemm_query_buffer, - const T* gemm_kv_buffer, - const bool* key_padding_mask, - const T* key_cache, - const T* value_cache, - T* qkv_buffer, - T* workspace_buffer, - T* output, - T* new_key_cache, - T* new_value_cache) { + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); const int max_threads_per_block = device_prop.maxThreadsPerBlock; - const int BN = batch_size * num_heads; - const int BHN = BN * head_size; - const int BNS = BN * sequence_length; - const int k_buffer_offset = sequence_length * BHN; - const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; - T* temp_qkv_buffer = workspace_buffer; - auto stream = static_cast(ort_stream->GetHandle()); + // At most one fused kernel is enabled. + assert((int(data.use_flash_attention) + + int(data.use_memory_efficient_attention) + + int(fused_runner != nullptr) + + int(data.fused_cross_attention_kernel != nullptr)) <= 1); - const T* q = qkv_buffer; - // transpose q and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); - - const T* k = qkv_buffer + k_buffer_offset; - const T* v = qkv_buffer + v_buffer_offset; - if (!has_layer_state || !use_past) { - if (!static_kv) { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } else { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } - } else { - if (!static_kv) { - // transpose kv and copy them to temp_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); - // concat cache-k with k and copy to qkv_buffer - if (nullptr != key_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, - sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, 1, - key_cache, - temp_qkv_buffer, - qkv_buffer + k_buffer_offset)); - } - // concat cache-v with v and copy to qkv_buffer - if (nullptr != value_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, - sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, 1, - value_cache, - temp_qkv_buffer + k_buffer_offset, - qkv_buffer + v_buffer_offset)); - } + ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block)); + + if (!parameters.past_present_share_buffer) { + ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, + sequence_length, total_sequence_length, parameters.pass_past_in_kv, + stream, max_threads_per_block, data)); + + } else { // past_present_share_buffer + assert(qk_head_size == v_head_size); + assert(data.fused_cross_attention_kernel == nullptr); + assert(nullptr == fused_runner || parameters.is_unidirectional); + assert(data.gemm_buffer != nullptr); + assert(!data.use_memory_efficient_attention); + assert(!data.use_flash_attention); + assert(data.has_qkv_workspace); + + if (nullptr != data.past_key || nullptr != data.present_key) { + // TODO: support this case. + ORT_THROW("buffer sharing for no bias case between past and present is not supported yet."); } - } - if (has_layer_state) { - if (use_past && static_kv) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - } else { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); + if (data.present != data.past) { + // For easy testing. Production should better avoid this path. + int64_t kv_size = 2LL * (int64_t)batch_size * num_heads * parameters.max_sequence_length * qk_head_size; + cudaMemcpyAsync(data.present, data.past, kv_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); } - } - // scratch1: BxNxSxL buffer - // scratch2: BxNxSxL buffer - // scratch3: BxNxSxH buffer - T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; - T* scratch2 = scratch1 + BNS * kv_sequence_length; - T* scratch3 = scratch2 + BNS * kv_sequence_length; - - // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL - // Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL - const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); - const int temp_matrix_size = sequence_length * kv_sequence_length; - float one = 1.0f; - float zero = 0.f; + // For fused causal, bias has been added to gemm_buffer. + const T* bias = (nullptr != fused_runner && parameters.is_unidirectional) ? nullptr : data.bias; - float alpha = rsqrt_head_size; - const int strideA = kv_sequence_length * head_size; - const int strideB = sequence_length * head_size; - if (use_past && static_kv) { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, - kv_sequence_length, sequence_length, head_size, - &alpha, key_cache, head_size, strideA, - q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); - } else { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, - kv_sequence_length, sequence_length, head_size, - &alpha, k, head_size, strideA, - q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + // append last k v to present + ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( + stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, max_threads_per_block, + bias, data.gemm_buffer, data.present)); + + data.k = data.present; + data.v = data.present + batch_size * num_heads * parameters.max_sequence_length * qk_head_size; } - constexpr bool is_unidirectional = false; - const T* add_before_softmax = nullptr; - if (has_key_padding_mask) { - constexpr int mask_dimension = 2; - constexpr int max_sequence_length = 0; - ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask(ort_stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, key_padding_mask, add_before_softmax, - false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, - 1.0f, mask_dimension, max_sequence_length, false, nullptr, - mask_filter_value)); - } else { - ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, num_heads, - add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, - is_unidirectional)); + // Q, K and V are ready now + if (data.fused_cross_attention_kernel != nullptr) { + return FusedTrtCrossAttention(stream, parameters, data); } - // compute P*V (as V*P), and store in scratch3: BxNxSxH - if (use_past && static_kv) { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, - head_size, sequence_length, kv_sequence_length, - &one, value_cache, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); - } else { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, - head_size, sequence_length, kv_sequence_length, - &one, v, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + // Run TRT fused attention. + if (nullptr != fused_runner) { + return FusedTrtSelfAttention(stream, parameters, data); } - // scratch3 is BxNxSxH, transpose to output SxBxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, scratch3, output); -} + // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) + : parameters.scale; -Status LaunchDecoderAttentionKernel( - const cudaDeviceProp& device_prop, - Stream* stream, - cublasHandle_t& cublas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const void* gemm_query_buffer, - const void* gemm_kv_buffer, - const bool* key_padding_mask, - const void* key_cache, - const void* value_cache, - void* qkv_buffer, - void* workspace_buffer, - void* output, - void* new_key_cache, - void* new_value_cache) { - if (element_size == 2) { - return DecoderQkvToContext( - device_prop, - stream, - cublas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } else { - return DecoderQkvToContext( - device_prop, - stream, - cublas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, stream, parameters, data, scale); + } +#endif + + return UnfusedAttention(device_prop, cublas, ort_stream, parameters, data, scale); } // Template Instantiation diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 5c63a8d8a8..3e78978c3c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -2,11 +2,12 @@ // Licensed under the MIT License. #pragma once -#include "core/providers/cuda/shared_inc/cuda_utils.h" + #include #include -#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/common/gsl.h" #include "core/framework/allocator.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace contrib { @@ -43,43 +44,63 @@ size_t GetAttentionWorkspaceSize( size_t kv_sequence_length, size_t total_sequence_length, void* fused_runner, + bool use_flash_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention); template struct AttentionData { - T* gemm_buffer; - const T* bias; + T* gemm_buffer = nullptr; + const T* bias = nullptr; - const T* query; - const T* key; - const T* value; - const int* mask_index; + const T* query = nullptr; + const T* key = nullptr; + const T* value = nullptr; + const int* mask_index = nullptr; gsl::span mask_index_dims; - const T* past; - const T* past_key; - const T* past_value; - const T* relative_position_bias; - - bool has_qkv_workspace; - T* workspace; - T* temp_k_workspace; - T* temp_v_workspace; - - T* output; - T* present; - T* present_key; - T* present_value; - - void* fused_runner; - const void* fused_cross_attention_kernel; - - bool use_memory_efficient_attention; - - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache; - mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache; + const T* past = nullptr; + const T* past_key = nullptr; + const T* past_value = nullptr; + const T* relative_position_bias = nullptr; + + bool has_qkv_workspace = false; + T* workspace = nullptr; + T* temp_k_workspace = nullptr; + T* temp_v_workspace = nullptr; + + T* output = nullptr; + T* present = nullptr; + T* present_key = nullptr; + T* present_value = nullptr; + + void* fused_runner = nullptr; + const void* fused_cross_attention_kernel = nullptr; + + bool use_flash_attention = false; + bool use_memory_efficient_attention = false; + + mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr; + mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr; + + // Intermediate data + T* q = nullptr; + T* k = nullptr; + T* v = nullptr; + T* scratch = nullptr; + AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + + // Flash buffers + T* softmax_lse = nullptr; + T* softmax_lse_accum = nullptr; + T* out_accum = nullptr; }; +template +Status PrepareQkv(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block); + template Status QkvToContext( const cudaDeviceProp& device_prop, @@ -88,33 +109,6 @@ Status QkvToContext( contrib::AttentionParameters& parameters, AttentionData& data); -Status LaunchDecoderAttentionKernel( - const cudaDeviceProp& prop, // Device Properties - Stream* stream, // ORT Stream - cublasHandle_t& cublas, // Cublas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - // BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true) Status LaunchTransCtx(cudaStream_t stream, const int sequence_length, const int batch_size, const int head_size, const int num_heads, @@ -159,33 +153,32 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, const half* tensor_add, half* tensor_out); -Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* past, - const float* k_v, - float* present); - -Status LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* past, - const half* k_v, - half* present); +template +Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data); + +template +Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int past_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const T* biases, + const T* qkv_buffer, + T* present); template Status LaunchStridedCopy(cudaStream_t stream, const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) T* out, longlong4 out_strides, // coord (b,n,s,h) int max_threads_per_block); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu new file mode 100644 index 0000000000..89be0f1115 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -0,0 +1,466 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void ConcatTensorToTensor(const int tensor_add_sequence_length, + const T* tensor_in, + const T* tensor_add, + T* tensor_out) { + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + const int chunk_id = blockIdx.z; + + const int all_sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + // K: number of identical tensors + // tensor_in: K x BxNxPxH + // tensor_add: K x BxNxLxH + // tensor_out: K x BxNxTxH, where T = P + L + const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length; + + const int present_SH = all_sequence_length * H; + const int present_NSH = num_heads * present_SH; + int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size); + if (s < tensor_in_sequence_length) { + const int past_SH = tensor_in_sequence_length * H; + const int past_NSH = num_heads * past_SH; + const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size); + tensor_out[out_offset] = tensor_in[in_offset]; + } else if (s < all_sequence_length) { + const int SH = tensor_add_sequence_length * H; + const int NSH = num_heads * SH; + const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size); + tensor_out[out_offset] = tensor_add[in_offset]; + } +} + +template +__global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length, + const int H, + const T* tensor_in, + const T* tensor_add, + T* tensor_out) { + // Use when (H*)*num_heads > 1024 + int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + const int chunk_id = blockIdx.z; + + const int all_sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int num_heads = blockDim.y; + const int stride = blockDim.x; + + // K: number of identical tensor + // tensor_in: K x BxNxPxH + // tensor_add: K x BxNxLxH + // tensor_out: K x BxNxTxH + const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length; + + const int present_SH = all_sequence_length * H; + const int present_NSH = num_heads * present_SH; + while (h < H) { + int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size); + if (s < tensor_in_sequence_length) { + const int past_SH = tensor_in_sequence_length * H; + const int past_NSH = num_heads * past_SH; + const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size); + tensor_out[out_offset] = tensor_in[in_offset]; + } else if (s < all_sequence_length) { + const int SH = tensor_add_sequence_length * H; + const int NSH = num_heads * SH; + const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size); + tensor_out[out_offset] = tensor_add[in_offset]; + } + + h += stride; + } +} + +Status LaunchConcatTensorToTensor(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const float* tensor_in, + const float* tensor_add, + float* tensor_out) { + const dim3 grid(all_sequence_length, batch_size, matrix_num); + if (0 == (head_size & 1)) { + const int H = head_size / 2; + if (H * num_heads <= max_threads_per_block) { + const dim3 block(H, num_heads, 1); + ConcatTensorToTensor<<>>(sequence_length, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<<>>(sequence_length, + H, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } + } else { + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + ConcatTensorToTensor<<>>(sequence_length, tensor_in, tensor_add, tensor_out); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<<>>(sequence_length, + head_size, + tensor_in, + tensor_add, + tensor_out); + } + } + return CUDA_CALL(cudaGetLastError()); +} + +Status LaunchConcatTensorToTensor(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const half* tensor_in, + const half* tensor_add, + half* tensor_out) { + const dim3 grid(all_sequence_length, batch_size, matrix_num); + if (0 == (head_size % 4)) { + const int H = head_size / 4; + if (H * num_heads <= max_threads_per_block) { + const dim3 block(H, num_heads, 1); + ConcatTensorToTensor<<>>(sequence_length, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<<>>(sequence_length, + H, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } + } else if (0 == (head_size & 1)) { + const int H = head_size / 2; + if (H * num_heads <= max_threads_per_block) { + const dim3 block(H, num_heads, 1); + ConcatTensorToTensor<<>>(sequence_length, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<<>>(sequence_length, + H, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast(tensor_out)); + } + } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel. + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + ConcatTensorToTensor<<>>(sequence_length, tensor_in, tensor_add, tensor_out); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<<>>(sequence_length, + head_size, + tensor_in, + tensor_add, + tensor_out); + } + } + return CUDA_CALL(cudaGetLastError()); +} + +Status LaunchConcatPastToPresent(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const float* past, + const float* k_v, + float* present) { + return LaunchConcatTensorToTensor( + stream, + all_sequence_length, + sequence_length, + batch_size, + head_size, + num_heads, + max_threads_per_block, + 2, + past, + k_v, + present); +} + +Status LaunchConcatPastToPresent(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const half* past, + const half* k_v, + half* present) { + return LaunchConcatTensorToTensor( + stream, + all_sequence_length, + sequence_length, + batch_size, + head_size, + num_heads, + max_threads_per_block, + 2, + past, + k_v, + present); +} + +#ifndef USE_ROCM // exclude the following from hipify since they are not used in ROCM EP + +template +Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data) { + // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. + // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) + // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) + // When there is past state, the head size for Q/K/V shall be same: H == H_v. + + if (nullptr != data.present) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + + ORT_RETURN_IF_ERROR( + LaunchConcatPastToPresent( + stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, data.past, data.k, data.present)); + + // Update pointers to present_k and present_v. + data.k = data.present; + data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; + } else if (nullptr != data.past_key || nullptr != data.present_key) { + if (nullptr != data.past_key && nullptr == data.present_key) { + data.k = const_cast(data.past_key); + data.v = const_cast(data.past_value); + } else if (nullptr == data.past_key && nullptr != data.present_key) { + if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + data.k = data.present_key; + data.v = data.present_value; + } else { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + data.k = data.temp_k_workspace; + data.v = data.temp_v_workspace; + } + } else if (pass_past_in_kv) { + // past_key and past_value are used directly as key and value in attention computations + data.k = const_cast(data.past_key); + data.v = const_cast(data.past_value); + + // This path has a memory copy from past_key and past_value to present_key and present_value + // Avoid this path since the memory copy is unnecessary because past_key == present_key and + // past_value == present_value + int64_t k_size = (int64_t)batch_size * num_heads * total_sequence_length * qk_head_size; + int64_t v_size = (int64_t)batch_size * num_heads * total_sequence_length * v_head_size; + cudaMemcpyAsync(data.present_key, data.past_key, k_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + } else { + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, + max_threads_per_block, 1, data.past_key, data.k, data.present_key)); + ORT_RETURN_IF_ERROR( + LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, + batch_size, v_head_size, num_heads, + max_threads_per_block, 1, data.past_value, data.v, data.present_value)); + // Update pointers to present_k and present_v. + data.k = data.present_key; + data.v = data.present_value; + } + } + + return CUDA_CALL(cudaGetLastError()); +} + +// Template Instantiation +template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data); + +template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, + int sequence_length, int total_sequence_length, bool pass_past_in_kv, + cudaStream_t stream, + int max_threads_per_block, + AttentionData& data); + +// ---------------------------------------------------------------------------------- +// Below kernels are for past and present sharing buffer +// ---------------------------------------------------------------------------------- + +template +__global__ void AddBiasTransAppendKvToPresentSmall( + const T* qkv, const T* biases, T* present, + const int head_size, const int past_sequence_length, const int max_sequence_length) { + // Input: BxSxMxNxH (Format 1) + // Output: (2, B, N, [P..P+S) of MaxS, H), + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + const int N = blockDim.y; + const int S = gridDim.x; + const int B = gridDim.y; + + constexpr int M = 3; // Matrix count in qkv + const int m = blockIdx.z + 1; // k = 1, v = 2 + + const int NH = N * head_size; + const int NHS = NH * S; + + qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); + if (biases) { + biases += (m * NH + n * head_size); + } + + const int MsH = max_sequence_length * head_size; + const int NMsH = N * MsH; + const int BNMsH = B * NMsH; + present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + + for (int h = threadIdx.x; h < head_size; h += blockDim.x) { + T bias = (biases ? biases[h] : (T)0.0f); + present[h] = qkv[h] + bias; + } +} + +template +__global__ void AddBiasTransAppendKvToPresent( + const T* qkv, const T* biases, T* present, + const int head_size, const int past_sequence_length, const int max_sequence_length) { + // Input: BxSxMxNxH (Format 1) + // Output: (2, B, N, [P..P+S) of MaxS, H), + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + const int n = blockIdx.x; + const int s = blockIdx.y; + const int b = (blockIdx.z >> 1); + const int N = gridDim.x; + const int S = gridDim.y; + const int B = (gridDim.z >> 1); + + constexpr int M = 3; // Matrix count in qkv + const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2 + + const int NH = N * head_size; + const int NHS = NH * S; + + qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); + if (biases) { + biases += (m * NH + n * head_size); + } + + const int MsH = max_sequence_length * head_size; + const int NMsH = N * MsH; + const int BNMsH = B * NMsH; + present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + + for (int h = threadIdx.x; h < head_size; h += blockDim.x) { + T bias = (biases ? biases[h] : (T)0.0f); + present[h] = qkv[h] + bias; + } +} + +// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3. +// bias is of shape (3, NxH) or nullptr +// append to present of (2, B, N, (P..T) of M, H), +template +Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int past_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const T* biases, + const T* qkv_buffer, + T* present) { + assert(head_size <= (1 << 30)); + + int64_t nh = (int64_t)head_size * num_heads; + if (nh <= max_threads_per_block) { + const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + + AddBiasTransAppendKvToPresentSmall<<>>( + qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + } else { + const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v + const dim3 block(std::min(head_size, max_threads_per_block), 1, 1); + AddBiasTransAppendKvToPresent<<>>( + qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + } + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const float* bias, + const float* qkv_buffer, + float* present); + +template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const half* bias, + const half* qkv_buffer, + half* present); +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu new file mode 100644 index 0000000000..5c65a30918 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -0,0 +1,492 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/bert/add_bias_transpose.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + const bool past_present_share_buffer = parameters.past_present_share_buffer; + void* fused_runner = data.fused_runner; + bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); + + if (data.bias == nullptr) { + assert(nullptr == fused_runner); + // For quantized attention, bias has been added so only need transpose here. + // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH + assert(qk_head_size == v_head_size); + int matrix_to_trans = (past_present_share_buffer ? 1 : 3); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.gemm_buffer, qkv, 3)); + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } else { + // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) + // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) + // For unfused kernel, transpose to 3xBxNxSxH (format 1) + // For fused causal kernel, use format 1 since we need have K and V to update present state, + // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. + const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); + qkv_format = use_fused_kernel + ? AttentionQkvFormat::QKV_BSN3H + : (use_flash_or_efficient_attention + ? AttentionQkvFormat::Q_K_V_BSNH + : (use_fused_causal + ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH + : AttentionQkvFormat::Q_K_V_BNSH)); + + // For fused causal, we will update gemm_buffer with bias directly. + T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; + + int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3); + // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v + // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) + LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, + 3, parameters.do_rotary, parameters.past_sequence_length); + } + return Status::OK(); +} + +// For MultiHeadAttention with past state +template +Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + + if (data.bias == nullptr) { + // Below logic does not support fused attention with past without bias + // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past. + + // cross attention with past state + if (data.past_key != nullptr && data.present_key == nullptr) { + assert(data.past_value != nullptr); + assert(data.query != nullptr); + assert(data.key == nullptr); + assert(data.value == nullptr); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + } + // cross attention with present state or self attention with present state + else if (data.past_key == nullptr && data.present_key != nullptr) { + assert(data.past_value == nullptr); + assert(data.present_value != nullptr); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + + // TODO: supporting packed qkv for self attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + + // TODO: supporting packed kv for cross attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.present_key)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, data.present_value)); + } + // self attention with past and present state + else { + assert(data.past_key != nullptr); + assert(data.past_value != nullptr); + assert(data.present_key != nullptr); + assert(data.present_value != nullptr); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + // TODO: supporting packed qkv for self attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, k)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, v)); + } + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value + else if ((data.use_memory_efficient_attention || data.use_flash_attention) && + data.past_key != nullptr && + data.past_value != nullptr && + parameters.pass_past_in_kv) { + // Transpose past_key and past_value to use memory efficient attention + + // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) + ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.past_key, data.temp_k_workspace)); + // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) + ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.past_value, data.temp_v_workspace)); + + // query => q, temp_k_workspace => k, temp_v_workspace => v + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + + data.past_key = nullptr; + data.past_value = nullptr; + } + // When there is no past_key/past_value and there is present_key/present_value + // (e.g. get initial kv to use as past_kv in the next iteration) + else if ((data.use_memory_efficient_attention || data.use_flash_attention) && + data.present_key != nullptr && + data.present_value != nullptr) { + // Use memory efficient attention kernel + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); + + // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.temp_k_workspace, data.present_key)); + + // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.temp_v_workspace, data.present_value)); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } +#endif + else { + // Use unfused kernel for Q, use unfused kernel for K and V if needed + constexpr int format = 0; + // Query (BxSxNxH) => Q (BxNxSxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, q, + true, -1); + + if (!parameters.pass_past_in_kv) { + T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k; + T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v; + + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, data.bias + num_heads * qk_head_size, k_dest, + true, -1); + + // Value (BxLxNxH_v) => V (BxNxLxH_v) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, + true, -1); + + DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size); + } + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with packed QKV inputs +template +Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); + + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, qkv, + true, v_head_size, qkv_add_bias, 3); + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (!use_fused_kernel) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); + } + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with packed KV inputs +template +Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. + // CheckInputs verified this constraint. + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); + + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // unpack kv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); + LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, kv_bias, k, + true, v_head_size, qkv_add_bias, 2); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (data.fused_cross_attention_kernel == nullptr) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); + } + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with Q, K and V inputs +template +Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); + + // gemm_buffer == nullptr and not packed + assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); + +#if DUMP_TENSOR_LEVEL > 1 + if (data.bias != nullptr) { + DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); + DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); + DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + } +#endif + + if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { + DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, + num_heads, sequence_length, kv_sequence_length); + } + + if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { + DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); + } + + if (data.fused_cross_attention_kernel != nullptr) { + assert(qk_head_size == v_head_size); + + // For fused cross attention, besides adding bias, K and V needed to be packed: + // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length); + + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + } +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + else if (data.use_memory_efficient_attention || data.use_flash_attention) { + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, q, k, v); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } +#endif + else if (use_fused_kernel) { + assert(qk_head_size == v_head_size); + + // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H) + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); + DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } else { // unfused kernel + ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal"); + + // Query (BxSxNxH) => Q (BxNxSxH) + constexpr int format = 0; + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, q, + true, -1); + + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, + true, -1); + + // Value (BxLxNxH_v) => K (BxNxLxH_v) + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, + true, -1); + + DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} + +template +Status PrepareQkv(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block) { + data.scratch = data.workspace; + if (data.has_qkv_workspace) { + const int size_per_batch_q = parameters.sequence_length * parameters.head_size; + const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size; + const int size_per_batch_v = parameters.kv_sequence_length * parameters.v_head_size; + const int batches = parameters.batch_size * parameters.num_heads; + const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); + const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); + const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); + data.q = data.workspace; + data.k = data.workspace + elements_q; + data.v = data.k + elements_k; + data.scratch = data.v + elements_v; + } + + if (nullptr != data.gemm_buffer) { // Attention operator + ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, + data.qkv_format)); + } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, + data.q, data.k, data.v, data.qkv_format)); + } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, + data.q, data.k, data.v, data.qkv_format)); + } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, + data.q, data.k, data.v, data.qkv_format)); + } else { // multihead attention operator, no past, separated Q/K/V inputs + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, + data.q, data.k, data.v, data.qkv_format)); + } + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + return Status::OK(); +} + +// Template Instantiation +template Status PrepareQkv( + contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block); + +template Status PrepareQkv( + contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index e7d2255fb4..01ea02f48d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -18,7 +18,6 @@ limitations under the License. */ #include -#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu b/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu index 2af748d8d4..32ed961a68 100644 --- a/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu +++ b/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu @@ -367,32 +367,32 @@ __global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK) const int* attention_masks, const int batch_size, const int sequence_length) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - const int batch_id = blockIdx.x; - const int* batch_mask = attention_masks + (batch_id * sequence_length); - const bool leftmost_non_zero = (batch_mask[0] != 0); - int biggest_position = 0; - - for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) { - if (leftmost_non_zero == (batch_mask[i] != 0)) { - biggest_position = i; - } else { - break; - } + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + const int batch_id = blockIdx.x; + const int* batch_mask = attention_masks + (batch_id * sequence_length); + const bool leftmost_non_zero = (batch_mask[0] != 0); + int biggest_position = 0; + + for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) { + if (leftmost_non_zero == (batch_mask[i] != 0)) { + biggest_position = i; + } else { + break; } + } - int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x); + int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x); - if (threadIdx.x == 0) { - int batch_offset = batch_id * sequence_length; - trt_mha_padding_offset[2 * batch_id] = batch_offset; - trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1; - if (batch_id == gridDim.x - 1) { - trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length; - } + if (threadIdx.x == 0) { + int batch_offset = batch_id * sequence_length; + trt_mha_padding_offset[2 * batch_id] = batch_offset; + trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1; + if (batch_id == gridDim.x - 1) { + trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length; } + } } // only support simple left padding with mask 0s on leading left, diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 00fa265e11..db78722cc0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #if defined(__GNUC__) #pragma GCC diagnostic push @@ -16,6 +16,133 @@ namespace onnxruntime { namespace contrib { namespace cuda { +template +struct RightPaddingBatchHook { + using scalar_t = typename AttentionKernel::scalar_t; + using accum_t = typename AttentionKernel::accum_t; + using lse_scalar_t = typename AttentionKernel::lse_scalar_t; + using output_t = typename AttentionKernel::output_t; + using output_accum_t = typename AttentionKernel::output_accum_t; + + static constexpr bool kSupportsDropout = AttentionKernel::kSupportsDropout; + static constexpr bool kSupportsBias = AttentionKernel::kSupportsBias; + static constexpr int kKeysPerBlock = AttentionKernel::kKeysPerBlock; + static constexpr bool kIsAligned = AttentionKernel::kIsAligned; + static constexpr bool kSingleValueIteration = AttentionKernel::kSingleValueIteration; + static constexpr int32_t kAlignLSE = AttentionKernel::kAlignLSE; // block size of backward + static constexpr bool kPreloadV = AttentionKernel::kPreloadV; + static constexpr bool kKeepOutputInRF = AttentionKernel::kKeepOutputInRF; + static constexpr bool kNeedsOutputAccumulatorBuffer = AttentionKernel::kNeedsOutputAccumulatorBuffer; + + template + static CUTLASS_DEVICE bool AdvanceToBlockForGQA(Params& p) { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE; + + // Advance to current batch - in case of different sequence lengths + if (p.seqlen_k_ptr) { + p.num_keys = p.seqlen_k_ptr[batch_id]; + } + + if (query_start >= p.num_queries) { + return false; + } + + // Advance to the current batch / head / query_start + p.query_ptr += batch_id * p.q_strideB + query_start * p.q_strideM + head_id * p.q_strideH; + p.key_ptr += batch_id * p.k_strideB + head_id * p.k_strideH; + p.value_ptr += batch_id * p.v_strideB + head_id * p.v_strideH; + p.output_ptr += int64_t(batch_id * p.num_queries) * p.o_strideM + int64_t(query_start) * p.o_strideM + head_id * p.head_dim_value; + + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + p.attn_bias_ptr += (batch_id * p.bias_strideB) + (head_id * p.bias_strideH); + } + if (p.output_accum_ptr != nullptr) { + p.output_accum_ptr += int64_t(batch_id * p.num_queries) * (p.head_dim_value * p.num_heads) + + int64_t(query_start) * (p.head_dim_value * p.num_heads) + + head_id * p.head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + p.output_accum_ptr = (accum_t*)(p.output_ptr); + } + + if (p.logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + p.logsumexp_ptr += + batch_id * lse_dim * p.num_heads + head_id * lse_dim + query_start; + } + + // Custom masking + if (p.causal_diagonal_ptr) { + p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id]; + } + if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) { + p.causal_diagonal_offset += p.num_keys - p.num_queries; + } + if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft || + p.custom_mask_type == AttentionKernel::CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys and + // this to avoid extra computations + p.num_keys = cutlass::fast_min( + int32_t(query_start + p.causal_diagonal_offset + kQueriesPerBlock), + p.num_keys); + } + + p.num_queries -= query_start; + p.num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (p.num_queries == 1 && p.k_strideH == 0 && p.v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) + return false; + p.q_strideM = p.q_strideH; + p.num_queries = p.num_heads; + p.num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + p.custom_mask_type = AttentionKernel::NoCustomMask; + p.o_strideM = p.head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + p.query_ptr = warp_uniform(p.query_ptr); + p.key_ptr = warp_uniform(p.key_ptr); + p.value_ptr = warp_uniform(p.value_ptr); + if (kSupportsBias) { + p.attn_bias_ptr = warp_uniform(p.attn_bias_ptr); + } + p.output_ptr = warp_uniform(p.output_ptr); + p.output_accum_ptr = warp_uniform(p.output_accum_ptr); + p.logsumexp_ptr = warp_uniform(p.logsumexp_ptr); + p.num_queries = warp_uniform(p.num_queries); + p.num_keys = warp_uniform(p.num_keys); + p.num_heads = warp_uniform(p.num_heads); + p.head_dim = warp_uniform(p.head_dim); + p.head_dim_value = warp_uniform(p.head_dim_value); + p.o_strideM = warp_uniform(p.o_strideM); + p.custom_mask_type = warp_uniform(p.custom_mask_type); + return true; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl_right_padding(typename AK::Params p) { + if (!RightPaddingBatchHook::AdvanceToBlockForGQA(p)) { + return; + } + AK::attention_kernel(p); +} + template void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { using Attention = AttentionKernel; @@ -51,28 +178,52 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.num_keys = params.kv_sequence_length; if (params.causal) { - p.custom_mask_type = Attention::CausalFromTopLeft; + p.custom_mask_type = Attention::CausalFromBottomRight; } - // Input format is BxSxNxH, output is BxSxNxH - p.q_strideH = params.qk_head_size; - p.k_strideH = params.qk_head_size; - p.v_strideH = params.v_head_size; - p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; - - p.q_strideM = params.num_heads * params.qk_head_size; - p.k_strideM = params.num_heads * params.qk_head_size; - p.v_strideM = params.num_heads * params.v_head_size; - p.o_strideM = params.num_heads * params.v_head_size; - p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; - - p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; - p.k_strideB = static_cast(p.k_strideM) * params.kv_sequence_length; - p.v_strideB = static_cast(p.v_strideM) * params.kv_sequence_length; - p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + // We use max_sequence_length to calculate KV stride + if (params.is_kv_bsnh) { + // Input Q, K, V format is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.qk_head_size; + p.v_strideH = params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.num_heads * params.qk_head_size; + p.v_strideM = params.num_heads * params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; + p.k_strideB = static_cast(p.k_strideM) * params.max_sequence_length; + p.v_strideB = static_cast(p.v_strideM) * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } else { + // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.max_sequence_length * params.qk_head_size; + p.v_strideH = params.max_sequence_length * params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.qk_head_size; + p.v_strideM = params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length; + p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length; + p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } + } + + auto kernel_fn = attention_kernel_batched_impl; + if (params.has_custom_right_padding) { + kernel_fn = attention_kernel_batched_impl_right_padding; } - constexpr auto kernel_fn = attention_kernel_batched_impl; int smem_bytes = sizeof(typename Attention::SharedStorage); if (smem_bytes > 0xc000) { ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!"); @@ -124,4 +275,4 @@ void DispatchBlockSize(const MemoryEfficientAttentionParams& params) { #pragma GCC diagnostic pop #endif -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu index 237f7ea8c9..540a269958 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h" @@ -21,4 +21,4 @@ void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& p } // namespace contrib } // namespace onnxruntime -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu index 941ea87baa..005425c56e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h" @@ -21,4 +21,4 @@ void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& p } // namespace contrib } // namespace onnxruntime -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu index 5a0e7c9ed5..955423b6c6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h" @@ -21,4 +21,4 @@ void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& p } // namespace contrib } // namespace onnxruntime -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu index d0775a29c4..0b54d90c4d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h" @@ -21,4 +21,4 @@ void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& p } // namespace contrib } // namespace onnxruntime -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu index 284211f965..750cace39a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" @@ -27,4 +27,4 @@ void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params } // namespace contrib } // namespace onnxruntime -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index 326ff451e6..484b783db1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -14,10 +14,12 @@ namespace cuda { struct MemoryEfficientAttentionParams { int32_t sm; bool is_half; + bool is_kv_bsnh = true; int32_t batch_size; int32_t num_heads; int32_t sequence_length; int32_t kv_sequence_length; + int32_t max_sequence_length; int32_t qk_head_size; int32_t v_head_size; bool causal; @@ -41,6 +43,8 @@ struct MemoryEfficientAttentionParams { static bool need_workspace(size_t v_head_size, bool is_float) { return (v_head_size > 128 && !is_float); } + + bool has_custom_right_padding = false; }; void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params); @@ -58,4 +62,4 @@ void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& p } // namespace contrib } // namespace onnxruntime -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index f907d30060..3f703ae3d0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/decoder_attention.h" +#include "contrib_ops/cuda/bert/decoder_attention_impl.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" #include "core/framework/op_kernel.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" @@ -85,7 +85,8 @@ Status CheckInputs(const TensorShape& query_shape, } if (kv_weights_dims[0] != hidden_size || kv_weights_dims[1] != 2 * static_cast(hidden_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "kv_weights shall have shape (hidden size, 2 * hidden size)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_weights shall have shape (hidden size, 2 * hidden size)"); } const auto& bias_dims = bias_shape.GetDims(); @@ -137,7 +138,8 @@ Status CheckInputs(const TensorShape& query_shape, const auto& value_cache_dims = value_cache->Shape().GetDims(); if (value_cache_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value_cache' is expected to have 4 dimension, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value_cache' is expected to have 4 dimension, got ", value_cache_dims.size()); } @@ -353,10 +355,12 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { } } - size_t bytes = element_size * batch_size * (static_cast(sequence_length) + static_cast(2) * kv_sequence_length) * hidden_size; + size_t bytes = element_size * batch_size * + (static_cast(sequence_length) + static_cast(2) * kv_sequence_length) * hidden_size; auto qkv_buffer_p = GetScratchBuffer(bytes, context->GetComputeStream()); - bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (static_cast(2) * head_size + static_cast(kv_sequence_length)); + bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * + (static_cast(2) * head_size + static_cast(kv_sequence_length)); auto workspace_p = GetScratchBuffer(bytes, context->GetComputeStream()); Tensor* output(context->Output(0, query_shape)); diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu new file mode 100644 index 0000000000..1dc22a9c8e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/bert/decoder_attention_impl.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/attention_softmax.h" + +using namespace onnxruntime::contrib::attention_softmax_cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status DecoderQkvToContext( + const cudaDeviceProp& device_prop, + Stream* ort_stream, + cublasHandle_t& cublas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const float mask_filter_value, + const T* gemm_query_buffer, + const T* gemm_kv_buffer, + const bool* key_padding_mask, + const T* key_cache, + const T* value_cache, + T* qkv_buffer, + T* workspace_buffer, + T* output, + T* new_key_cache, + T* new_value_cache) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int BN = batch_size * num_heads; + const int BHN = BN * head_size; + const int BNS = BN * sequence_length; + const int k_buffer_offset = sequence_length * BHN; + const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; + + T* temp_qkv_buffer = workspace_buffer; + auto stream = static_cast(ort_stream->GetHandle()); + + const T* q = qkv_buffer; + // transpose q and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); + + const T* k = qkv_buffer + k_buffer_offset; + const T* v = qkv_buffer + v_buffer_offset; + if (!has_layer_state || !use_past) { + if (!static_kv) { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } else { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } + } else { + if (!static_kv) { + // transpose kv and copy them to temp_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); + // concat cache-k with k and copy to qkv_buffer + if (nullptr != key_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, + sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, 1, + key_cache, + temp_qkv_buffer, + qkv_buffer + k_buffer_offset)); + } + // concat cache-v with v and copy to qkv_buffer + if (nullptr != value_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, + sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, 1, + value_cache, + temp_qkv_buffer + k_buffer_offset, + qkv_buffer + v_buffer_offset)); + } + } + } + + if (has_layer_state) { + if (use_past && static_kv) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + } + } + + // scratch1: BxNxSxL buffer + // scratch2: BxNxSxL buffer + // scratch3: BxNxSxH buffer + T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; + T* scratch2 = scratch1 + BNS * kv_sequence_length; + T* scratch3 = scratch2 + BNS * kv_sequence_length; + + // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL + // Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL + const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); + const int temp_matrix_size = sequence_length * kv_sequence_length; + float one = 1.0f; + float zero = 0.f; + + float alpha = rsqrt_head_size; + const int strideA = kv_sequence_length * head_size; + const int strideB = sequence_length * head_size; + if (use_past && static_kv) { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + kv_sequence_length, sequence_length, head_size, + &alpha, key_cache, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + } else { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + kv_sequence_length, sequence_length, head_size, + &alpha, k, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + } + + constexpr bool is_unidirectional = false; + const T* add_before_softmax = nullptr; + if (has_key_padding_mask) { + constexpr int mask_dimension = 2; + constexpr int max_sequence_length = 0; + ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( + ort_stream, kv_sequence_length, sequence_length, batch_size, + num_heads, nullptr, key_padding_mask, add_before_softmax, + false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, + 1.0f, mask_dimension, max_sequence_length, false, nullptr, + mask_filter_value)); + } else { + ORT_RETURN_IF_ERROR(ComputeSoftmax( + stream, kv_sequence_length, sequence_length, batch_size, num_heads, + add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, + is_unidirectional)); + } + + // compute P*V (as V*P), and store in scratch3: BxNxSxH + if (use_past && static_kv) { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, kv_sequence_length, + &one, value_cache, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN, device_prop)); + } else { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, kv_sequence_length, + &one, v, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN, device_prop)); + } + + // scratch3 is BxNxSxH, transpose to output SxBxNxH + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, scratch3, output); +} + +Status LaunchDecoderAttentionKernel( + const cudaDeviceProp& device_prop, + Stream* stream, + cublasHandle_t& cublas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const float mask_filter_value, + const void* gemm_query_buffer, + const void* gemm_kv_buffer, + const bool* key_padding_mask, + const void* key_cache, + const void* value_cache, + void* qkv_buffer, + void* workspace_buffer, + void* output, + void* new_key_cache, + void* new_value_cache) { + if (element_size == 2) { + return DecoderQkvToContext( + device_prop, + stream, + cublas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } else { + return DecoderQkvToContext( + device_prop, + stream, + cublas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h new file mode 100644 index 0000000000..9db9ccb45e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cuda/bert/attention_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +Status LaunchDecoderAttentionKernel( + const cudaDeviceProp& prop, // Device Properties + Stream* stream, // ORT Stream + cublasHandle_t& cublas, // Cublas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const float mask_filter_value, // Mask filter value + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor +); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu index a2dfca8cd6..ae53eca541 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu @@ -86,10 +86,10 @@ __global__ void MaskIndexKernel(int sequence_length, const int* mask, int* mask_ } inline Status ComputeMaskIndex(cudaStream_t stream, - const int sequence_length, - const int batch_size, - const int* mask, - int* mask_index) { + const int sequence_length, + const int batch_size, + const int* mask, + int* mask_index) { // Mask idx is of length batch_size and assumes the valid region is contiguous starting // from the beginning of the sequence @@ -133,7 +133,7 @@ __global__ void EmbedLayerNormKernel( } if (nullptr == position_ids) { position_id = blockIdx.x; - } else if (broadcast_position_ids){ + } else if (broadcast_position_ids) { position_id = position_ids[sequence_position % gridDim.x]; } else { position_id = position_ids[sequence_position]; @@ -212,13 +212,12 @@ Status LaunchEmbedLayerNormKernel( void* embedding_sum, const int* position_ids, const bool broadcast_position_ids) { - if (mask_index != nullptr) { if (nullptr == input_mask) { CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_index, 0, sizeof(int) * batch_size, stream)); } else { ORT_RETURN_IF_ERROR( - ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast(mask_index))); + ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast(mask_index))); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu index 1b0de47a83..c9498eb1bc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu @@ -66,7 +66,7 @@ __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int template <> Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, - const float* input, const float* bias, float* output, bool /*use_half2*/) { + const float* input, const float* bias, float* output, bool /*use_half2*/) { constexpr int blockSize = 256; const int gridSize = (input_length + blockSize - 1) / blockSize; FastGeluKernel<<>>(A, B, C, input_length, bias_length, @@ -77,7 +77,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int template <> Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, - const half* input, const half* bias, half* output, bool use_half2) { + const half* input, const half* bias, half* output, bool use_half2) { constexpr int blockSize = 256; if (use_half2 && 0 == (bias_length & 1) && prop.major >= 7) { const int n = input_length / 2; @@ -101,7 +101,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int template <> Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, - const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) { + const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) { constexpr int blockSize = 256; // remove nv_bfloat162 implementation for now to fix build issue diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h new file mode 100644 index 0000000000..811b1be7d4 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h @@ -0,0 +1,46 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +namespace onnxruntime { +namespace flash { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + template + __device__ BlockInfo(const Params& params, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]), + sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]), + actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , + seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])), + actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { + } + + template + inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h new file mode 100644 index 0000000000..89e2351428 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -0,0 +1,114 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include +#include + +namespace onnxruntime { +namespace flash { + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = uint32_t; + // The QKV matrices. + void* __restrict__ q_ptr = nullptr; + void* __restrict__ k_ptr = nullptr; + void* __restrict__ v_ptr = nullptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride = 0; + index_t k_batch_stride = 0; + index_t v_batch_stride = 0; + index_t q_row_stride = 0; + index_t k_row_stride = 0; + index_t v_row_stride = 0; + index_t q_head_stride = 0; + index_t k_head_stride = 0; + index_t v_head_stride = 0; + + // The number of heads. + int h = 0; + int h_k = 0; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio = 0; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + // The O matrix (output). + void* __restrict__ o_ptr = nullptr; + void* __restrict__ oaccum_ptr = nullptr; + + // The stride between rows of O. + index_t o_batch_stride = 0; + index_t o_row_stride = 0; + index_t o_head_stride = 0; + + // The pointer to the P matrix. + void* __restrict__ p_ptr = nullptr; + + // The pointer to the softmax sum. + void* __restrict__ softmax_lse_ptr = nullptr; + void* __restrict__ softmax_lseaccum_ptr = nullptr; + + // The dimensions. + int b = 0; + int seqlen_q = 0; + int seqlen_k = 0; + int seqlen_knew = 0; + int d = 0; + int seqlen_q_rounded = 0; + int seqlen_k_rounded = 0; + int d_rounded = 0; + + // The scaling factors for the kernel. + float scale_softmax = 0.0; + float scale_softmax_log2 = 0.0; + + // array of length b+1 holding starting offset of each sequence. + int* __restrict__ cu_seqlens_q = nullptr; + int* __restrict__ cu_seqlens_k = nullptr; + + int* __restrict__ blockmask = nullptr; + + // The K_new and V_new matrices. + void* __restrict__ knew_ptr = nullptr; + void* __restrict__ vnew_ptr = nullptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride = 0; + index_t vnew_batch_stride = 0; + index_t knew_row_stride = 0; + index_t vnew_row_stride = 0; + index_t knew_head_stride = 0; + index_t vnew_head_stride = 0; + + bool is_bf16 = false; + bool is_causal = false; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative = true; + int num_splits = 0; // For split-KV version + + const cudaDeviceProp* dprops = nullptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream); +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc new file mode 100644 index 0000000000..89a27c4d2b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -0,0 +1,424 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/bert/flash_attention/flash.h" +#include "contrib_ops/cuda/bert/flash_attention/static_switch.h" + +namespace onnxruntime { +namespace flash { + +void set_params_fprop(Flash_fwd_params& params, + // sizes + size_t batch_size, + size_t seqlen_q, + size_t seqlen_k, + size_t seqlen_q_rounded, + size_t seqlen_k_rounded, + size_t num_heads, + size_t num_heads_k, + size_t head_size, + size_t head_size_rounded, + // device pointers + void* q, + void* k, + void* v, + void* out, + void* cu_seqlens_q_d, + void* cu_seqlens_k_d, + void* p_d, + void* softmax_lse_d, + float softmax_scale, + bool is_causal, + bool kv_bsnh = true) { + // Set the pointers and strides. + params.q_ptr = q; + params.k_ptr = k; + params.v_ptr = v; + params.o_ptr = out; + + params.is_bf16 = false; + + // All stride are in elements, not bytes. + if (kv_bsnh) { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = num_heads_k * head_size; + params.v_row_stride = num_heads_k * head_size; + params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } else { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = head_size; + params.v_row_stride = head_size; + params.q_head_stride = head_size; + params.k_head_stride = seqlen_k * head_size; + params.v_head_stride = seqlen_k * head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + } else { + params.q_batch_stride = 0; + params.k_batch_stride = 0; + params.v_batch_stride = 0; + params.o_batch_stride = 0; + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.h_h_k_ratio = num_heads / num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + params.is_causal = is_causal; + params.is_seqlens_k_cumulative = true; +} + +size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) { + size_t bytes = sizeof(float) * batch_size * num_heads * seqlen; + return bytes; +} + +size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads; + return bytes; +} + +size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded; + return bytes; +} + +void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) { + FP16_SWITCH(!params.is_bf16, [&] { + FWD_HEADDIM_SWITCH(params.d, [&] { + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); + }); +} + +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 85% +// of the best efficiency. +int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, + int max_splits) { + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. + const int num_m_blocks = (seqlen_q + 64 - 1) / 64; + int batch_nheads_mblocks = batch_size * num_heads * num_m_blocks; + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { + return 1; + } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. + auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + continue; + } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes) +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs) { + int max_splits = 128; + // split kv buffers + int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, + num_SMs, max_splits); + if (num_splits > 1) { + // softmax_lse_accum buffer + int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q); + // out_accum buffer + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes}; + } else { + return {0, 0, 0}; + } +} + +Status mha_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* k, // batch_size x seqlen_k x num_heads_k x head_size + void* v, // batch_size x seqlen_k x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + float softmax_scale, + bool is_causal, + int num_splits, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + bool kv_bsnh) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + /*cu_seqlens_q*/ nullptr, + /*cu_seqlens_k*/ nullptr, + nullptr, + softmax_lse, + softmax_scale, + is_causal, + kv_bsnh); + params.dprops = &dprops; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + + params.num_splits = num_splits; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + run_mha_fwd(params, stream); + return Status::OK(); +} + +Status mha_varlen_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // half (total_q, num_heads, head_size) + void* k, // half (total_k, num_heads, head_size) + void* v, // half (total_k, num_heads, head_size) + void* out, // half (total_q, num_heads, head_size) + int* cu_seqlens_q, // int (batch_size + 1) + int* cu_seqlens_k, // int (batch_size + 1) + void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int max_seqlen_q, + int max_seqlen_k, + float softmax_scale, + bool is_causal) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + cu_seqlens_q, + cu_seqlens_k, + nullptr, + softmax_lse, + softmax_scale, + is_causal); + params.dprops = &dprops; + params.num_splits = 0; + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + run_mha_fwd(params, stream); + return Status::OK(); +} + +bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k) { + bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; + bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + return (is_sm8x || is_sm90) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0); +} + +// This API is used when past key and value are present... since cached, these are assumed to have sequence length +// of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_. +Status mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size + void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + const float softmax_scale, + bool is_causal, + bool past_bsnh, // otherwise bnsh + int num_splits, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded +) { + if (seqlen_q == 1) { + is_causal = false; + } // causal=true is the same as causal=false in this case + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, kcache, vcache, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse, + softmax_scale, + is_causal, + past_bsnh); + params.dprops = &dprops; + + if (k != nullptr && v != nullptr) { + params.seqlen_knew = seqlen_k_new; + params.knew_ptr = k; + params.vnew_ptr = v; + // All stride are in elements, not bytes. + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + params.knew_head_stride = head_size; + params.vnew_head_stride = head_size; + } else { + params.seqlen_knew = 0; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + } + + params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; + if (seqlens_k_ != nullptr) { + params.cu_seqlens_k = static_cast(seqlens_k_); + } + + params.num_splits = num_splits; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + // Only split kernel supports appending to KV cache + run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr); + + return Status::OK(); +} + +} // namespace flash +} // namespace onnxruntime + +#endif // USE_FLASH_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h new file mode 100644 index 0000000000..58f4304251 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -0,0 +1,112 @@ +/****************************************************************************** + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#if USE_FLASH_ATTENTION + +#include "core/providers/cuda/cuda_common.h" +#include + +namespace onnxruntime { +namespace flash { + +Status mha_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* k, // batch_size x seqlen_k x num_heads_k x head_size + void* v, // batch_size x seqlen_k x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + float softmax_scale, + bool is_causal, + int num_splits = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + bool kv_bsnh = true); + +Status mha_varlen_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // half (total_q, num_heads, head_size) + void* k, // half (total_k, num_heads, head_size) + void* v, // half (total_k, num_heads, v_head_size) + void* out, // half (total_q, num_heads, v_head_size) + int* cu_seqlens_q, // int (batch_size + 1) + int* cu_seqlens_k, // int (batch_size + 1) + void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int max_seqlen_q, + int max_seqlen_k, + float softmax_scale, + bool is_causal); + +Status mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* k, // batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + const float softmax_scale, + bool is_causal, + bool past_bsnh, // otherwise bnsh + int num_splits = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded +); + +size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); + +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs); + +bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); + +} // namespace flash +} // namespace onnxruntime + +#endif // USE_FLASH_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 0000000000..44ea92e58c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_fp16_sm80.cu new file mode 100644 index 0000000000..a2bf16bc74 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu new file mode 100644 index 0000000000..56fc04126a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_fp16_sm80.cu new file mode 100644 index 0000000000..6fb2464071 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 0000000000..94d51e922d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu new file mode 100644 index 0000000000..d32eec2763 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 0000000000..65a2e42192 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 0000000000..f37ee50058 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h new file mode 100644 index 0000000000..eb1c794d6d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -0,0 +1,1240 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#endif + +#include +#include +#include + +#include +#include +#include +#include + +#include "contrib_ops/cuda/bert/flash_attention/block_info.h" +#include "contrib_ops/cuda/bert/flash_attention/kernel_traits.h" +#include "contrib_ops/cuda/bert/flash_attention/utils.h" +#include "contrib_ops/cuda/bert/flash_attention/softmax.h" + +namespace onnxruntime { +namespace flash { +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE auto +make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; + constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; + constexpr int MMAStride_M = MMA_M * AtomShape_M; + auto t = make_tile(cute::Layout, cute::Int>, + cute::Stride<_1, cute::Int>>{}, + make_layout(cute::size<2>(TileShape_MNK{}))); + + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE auto +make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; + constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; + constexpr int MMAStride_M = MMA_M * AtomShape_M; + auto t = make_tile(cute::Layout, cute::Int>, + cute::Stride<_1, cute::Int>>{}, + // TODO: Shouldn't this be size<1>? + make_layout(cute::size<2>(TileShape_MNK{}))); + // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, + Tensor2& acc_o, float softmax_scale_log2) { + if (Is_first) { + flash::template reduce_max(scores, scores_max); + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + flash::reduce_sum(scores, scores_sum); + } else { + cute::Tensor scores_max_prev = make_fragment_like(scores_max); + cute::copy(scores_max, scores_max_prev); + flash::template reduce_max(scores, scores_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); +#pragma unroll + for (int mi = 0; mi < cute::size(scores_max); ++mi) { + float scores_max_cur = !Check_inf + ? scores_max(mi) + : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scores_sum(mi) *= scores_scale; +#pragma unroll + for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale; + } + } + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + cute::Tensor scores_sum_cur = make_fragment_like(scores_sum); + flash::reduce_sum(scores, scores_sum_cur); +#pragma unroll + for (int mi = 0; mi < cute::size(scores_sum); ++mi) { + scores_sum(mi) += scores_sum_cur(mi); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void write_softmax_to_gmem( + cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_tiled_copy_P) { + // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) + cute::Layout l = tOrP.layout(); + cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); + CUTE_STATIC_ASSERT_V(cute::size<2>(tPgP) == _1{}); + CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP)); +#pragma unroll + for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) { + cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal) { + n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + // We exit early and write 0 to gO and gLSE. + // Otherwise we might read OOB elements from gK and gV. + if (n_block_max <= 0) { + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSE(row) = INFINITY; + } + } + return; + } + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + cute::Shape, cute::Int>{}, + make_stride(params.q_row_stride, _1{})); + cute::Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + cute::Shape, cute::Int>{}, + make_stride(params.k_row_stride, _1{})); + cute::Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + cute::Shape, cute::Int>{}, + make_stride(params.v_row_stride, _1{})); + cute::Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + cute::Shape, cute::Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + cute::Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + cute::Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : cute::size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + cute::Tensor sV = make_tensor(sK.data() + cute::size(sK), typename Kernel_traits::SmemLayoutKV{}); + cute::Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + cute::Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; + auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); + + cute::Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + cute::Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + cute::Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + cute::Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + cute::Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + cute::Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + cute::Tensor tPgP = gmem_thr_copy_P.partition_D(gP); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + cute::Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + cute::Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + cute::Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + cute::Tensor acc_o = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + cute::Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + cute::Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + cute::Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + cute::Tensor scores_max = make_tensor(cute::Shape(acc_o)>>{}); + cute::Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // Construct identity layout for sQ and sK + cute::Tensor cQ = make_identity_tensor(make_shape(cute::size<0>(sQ), cute::size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + cute::Tensor cKV = make_identity_tensor(make_shape(cute::size<0>(sK), cute::size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + cute::Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + cute::Tensor tQpQ = make_tensor(make_shape(cute::size<2>(tQsQ))); + cute::Tensor tKVpKV = make_tensor(make_shape(cute::size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < cute::size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < cute::size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // Prologue + + cute::Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { + cute::cp_async_fence(); + } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal + ? 1 + : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal) { + if (!Is_even_MN) { + flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); + } + } else { + // I can't get the stride from idx_row + flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > 0) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + cute::Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // if (Return_softmax) { + // cute::Tensor tOrP_copy = make_fragment_like(tOrP); + // copy(tOrP, tOrP_copy); + // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + // tPgP.data() = tPgP.data() + (-kBlockN); + // } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= 0) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= 0; --n_block) { + cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > 0) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + cute::Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // if (Return_softmax) { + // cute::Tensor tOrP_copy = make_fragment_like(tOrP); + // copy(tOrP, tOrP_copy); + // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + // tPgP.data() = tPgP.data() + (-kBlockN); + // } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + cute::Tensor lse = make_fragment_like(scores_sum); +#pragma unroll + for (int mi = 0; mi < cute::size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + + // Convert acc_o from fp32 to fp16/bf16 + cute::Tensor rO = flash::convert_type(acc_o); + cute::Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + cute::Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + cute::Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + cute::Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + cute::Shape, cute::Int>{}, + make_stride(params.o_row_stride, _1{})); + cute::Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + cute::Shape>{}, cute::Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + cute::Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + cute::Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + cute::Tensor tOrO = make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + cute::Tensor caccO = make_identity_tensor(cute::Shape, cute::Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + cute::Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(cute::size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + cute::Tensor taccOcO_row = logical_divide(taccOcO, cute::Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(cute::size(lse) == cute::size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < cute::size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { + gLSE(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + cute::Tensor cO = make_identity_tensor(make_shape(cute::size<0>(sO), cute::size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + cute::Tensor tOpO = make_tensor(make_shape(cute::size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < cute::size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyOaccum, + typename Kernel_traits::GmemTiledCopyO>; + using ElementO = std::conditional_t; + + const BlockInfo binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = n_split_idx * n_blocks_per_split; + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSEaccum(row) = Split ? -INFINITY : INFINITY; + } + } + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // Prologue + + Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy_2_sources( + gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + cute::cp_async_fence(); + + // flash::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } + // __syncthreads(); + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal + ? 1 + : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + if constexpr (Append_KV) { + // if (cute::thread0()) { print(tKgK); } + // if (cute::thread0()) { print(tKsK); } + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } + if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { + flash::copy_w_min_idx( + tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } + // __syncthreads(); + // if (cute::thread0()) { print(tKgK); } + // __syncthreads(); + } + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + if (Append_KV) { + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + } + flash::copy_2_sources( + gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy_2_sources( + gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal) { + if (!Is_even_MN) { + flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); + } + } else { + flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); } + if constexpr (Append_KV) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } + if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { + flash::copy_w_min_idx( + tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } + } + + if (n_block > n_block_min) { + // Advance gK + // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + if (Append_KV) { + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); } + flash::copy_2_sources( + gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, + binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + if constexpr (Append_KV) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } + if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { + flash::copy_w_min_idx( + tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } + } + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + if (Append_KV) { + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + } + flash::copy_2_sources( + gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + flash::cp_async_wait<0>(); + __syncthreads(); + if constexpr (Append_KV) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } + if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { + flash::copy_w_min_idx( + tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } + } + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + if (Append_KV) { + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + flash::copy_2_sources( + gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, + binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + // if (cute::thread0()) { print(acc_o_rowcol); } + Tensor lse = make_fragment_like(scores_sum); +#pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + // if (cute::thread0()) { print(lse); } + // if (cute::thread0()) { print(acc_o_rowcol); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum>; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { + gLSEaccum(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); + // __syncthreads(); + // if (cute::thread0()) { print(tOgOaccum); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn(const Params& params) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_splitkv(const Params& params) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; + // The block index for the head. + const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; + const int n_split_idx = Split ? blockIdx.y : 0; + const int num_n_splits = Split ? gridDim.y : 1; + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void combine_attn_seqk_parallel(const Params& params) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + constexpr int kMaxSplits = 1 << Log_max_splits; + constexpr int kBlockM = 16; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer"); + static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32"); + static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads"); + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; + + // The thread and block index. + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape, Int>{}, + make_stride(params.b * params.h * params.seqlen_q, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then tranpose them. + constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM; +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { + sLSE[row][col] = lse; + } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + __syncthreads(); + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // 16 rows, so each time we load we can load 8 rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_max = max(lse_max, lse_accum(l)); + } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_sum += expf(lse_accum(l) - lse_max); + } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } +// Store the scales exp(lse - lse_logsum) in shared memory. +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { + sLSE[row][col] = expf(lse_accum(l) - lse_logsum); + } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { + tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; + } + } +// Load Oaccum in then scale and accumulate to O +#pragma unroll 2 + for (int split = 0; split < params.num_splits; ++split) { + flash::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE[split][row]; +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + } + } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + // if (cute::thread0()) { print(tOrO); } + + Tensor rO = flash::convert_type(tOrO); +// Write to gO +#pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.b * params.h * params.seqlen_q) { + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; +#pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), + Shape(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); + } + } + } + } +} + +} // namespace flash +} // namespace onnxruntime + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h new file mode 100644 index 0000000000..82dfa59b8f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -0,0 +1,287 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include "contrib_ops/cuda/bert/flash_attention/static_switch.h" +#include "contrib_ops/cuda/bert/flash_attention/flash.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h" + +namespace onnxruntime { +namespace flash { + +template +__global__ void flash_fwd_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + flash::compute_attn(params); +#else + (void)params; +#endif +} + +template +__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + flash::compute_attn_splitkv(params); +#else + (void)params; +#endif +} + +template +__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static_assert(Log_max_splits >= 1); + flash::combine_attn_seqk_parallel(params); +#else + (void)params; +#endif +} + +template +void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + // Will only return softmax if dropout, to reduce compilation time. + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); + }); +} + +template +void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); + }); + }); + }); + }); + if (params.num_splits > 1) { + dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16); + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + }); + } +} + +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int kBlockM = 64; // Fixed for all head dimensions + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); +} + +template +void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 32; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 64; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 96; + const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 128; + const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + }); +} + +template +void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 160; + const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 192; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { + constexpr size_t Headdim = 224; + constexpr size_t threshold = 2 * Headdim * (128 + 2 * 64); + size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + // printf("max_smem_per_block = %d\n", max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (max_smem_per_block >= threshold) { // 112 KB + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. + // If we have N = 32, there are only 1024 elements to load at once, where each load + // is 8 elements. This means we can only use 128 threads and not 256 threads. + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { + constexpr size_t Headdim = 256; + constexpr size_t min_threshold = 2 * Headdim * (128 + 2 * 64); + constexpr size_t max_threshold = 4 * Headdim * (64 + 2 * 64); + size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; + size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= min_threshold && max_smem_per_sm < max_threshold) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu new file mode 100644 index 0000000000..68ae2ea759 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu new file mode 100644 index 0000000000..94564a6aba --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu new file mode 100644 index 0000000000..ec9e9e738c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu new file mode 100644 index 0000000000..e6c4ff5d95 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu new file mode 100644 index 0000000000..552966852c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu new file mode 100644 index 0000000000..e9f191a482 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu new file mode 100644 index 0000000000..d628a55668 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu new file mode 100644 index 0000000000..88b6cc0fb1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h new file mode 100644 index 0000000000..134f159e25 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h @@ -0,0 +1,362 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +using namespace cute; + +namespace onnxruntime { +namespace flash { + +template +struct Flash_kernel_traits { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom>; + using ValLayoutMNK = cute::Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = cute::Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, _1, _1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype(composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomVtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomVtransposedNoSwizzle{}, + Shape, Int>{})); + // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; + + static constexpr int kSmemQCount = cute::size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = cute::Layout, cute::Int>, + cute::Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = cute::Layout, cute::Int>, + cute::Stride, _1>>; + + using GmemTiledCopyP = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + cute::Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_8, _1>>, + cute::Layout, // Thread layout, 16 threads per row + cute::Stride<_16, _1>>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + cute::Layout>{})); // Val layout, 4 vals per store +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + cute::Layout, cute::Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + cute::Layout, cute::Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + cute::Layout, cute::Int, _1>>, // 2x4x1 or 4x2x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQdO = decltype(composition(Swizzle{}, + cute::Layout>, + cute::Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + cute::make_shape(cute::Int{}, cute::Int{}))); + + using SmemLayoutAtomKV = decltype(composition(Swizzle{}, + cute::Layout, cute::Int>, + cute::Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + cute::make_shape(cute::Int{}, cute::Int{}))); + + using SmemLayoutAtomKtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomKtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomKtransposedNoSwizzle{})); + using SmemLayoutKtransposed = decltype(tile_to_shape( + SmemLayoutAtomKtransposed{}, + make_shape(Int{}, Int{}))); + // Maybe the KtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomKtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + static_assert(kBlockN >= 64); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = 64; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype(composition(Swizzle{}, + cute::Layout, cute::Int>, + cute::Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + cute::make_shape(cute::Int{}, cute::Int{}))); + using SmemLayoutAtomPdStransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomPdStransposed = decltype(composition(Swizzle{}, SmemLayoutAtomPdStransposedNoSwizzle{})); + using SmemLayoutPdStransposed = decltype(tile_to_shape( + SmemLayoutAtomPdStransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomPdStransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemCopyAtomPdS = Copy_Atom; + + using SmemLayoutAtomQdOtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomQdOtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); + using SmemLayoutQdOtransposed = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + + using SmemLayoutAtomdKV = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom; + + using SmemLayoutAtomdQ = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom; + + static constexpr int kSmemQdOCount = cute::size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ + static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2; + static constexpr int kSmemdSCount = cute::size(SmemLayoutPdS{}); + static constexpr int kSmemPCount = cute::size(SmemLayoutPdS{}); + static constexpr int kSmemdQCount = cute::size(SmemLayoutdQ{}); + // static constexpr int kSmemdPsumCount = kBlockM; + static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); + static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); + static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); + // static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); + static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + kSmemdSSize + kSmemPSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = cute::Layout, cute::Int>, + cute::Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_8, _1>>, + cute::Layout, // Thread layout, 16 threads per row + cute::Stride<_16, _1>>>; + using GmemTiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + cute::Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype(make_tiled_copy(Copy_Atom{}, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_32, _1>>{}, + cute::Layout>{})); // Val layout, 1 val per store +}; + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h new file mode 100644 index 0000000000..842edf3a98 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -0,0 +1,206 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include + +#include + +#include +#include + +#include "contrib_ops/cuda/bert/flash_attention/utils.h" + +namespace onnxruntime { +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void thread_reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ inline void quad_allreduce_(Tensor& dst, Tensor& src, Operator& op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ inline void reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ inline void reduce_max(Tensor const& tensor, Tensor& max) { + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ inline void reduce_sum(Tensor const& tensor, Tensor& sum) { + SumOp sum_op; + reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +inline __device__ void scale_apply_exp2(Tensor& tensor, Tensor const& max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +// Apply the exp to all the elements. +template +inline __device__ void max_scale_exp2_sum(Tensor& tensor, Tensor& max, Tensor& sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +inline __device__ void apply_mask(Tensor& tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { +// Without the "make_coord" we get wrong results +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + // const int row_idx_offset = row_idx_offset_ + lane_id / 4; + const int row_idx_offset = row_idx_offset_; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +inline __device__ void apply_mask_causal_w_idx( + Tensor& tensor, Tensor const& idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); +#pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h new file mode 100644 index 0000000000..05ac247669 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h @@ -0,0 +1,60 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + assert(COND); \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + }() + +#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 160) { \ + constexpr static int kHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr static int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 224) { \ + constexpr static int kHeadDim = 224; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h new file mode 100644 index 0000000000..02042e183f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h @@ -0,0 +1,405 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// +namespace onnxruntime { +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t relu2(const uint32_t x); + +template <> +inline __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" + : "=r"(res) + : "r"(x), "r"(zero)); +#else + asm volatile( + "{\n" + "\t .reg .f16x2 sela;\n" + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" + "\t and.b32 %0, sela, %1;\n" + "}\n" + : "=r"(res) + : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" + : "=r"(res) + : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template +inline __device__ uint32_t convert_relu2(const float2 x); + +template <> +inline __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" + : "=r"(res) + : "r"(b), "r"(a)); + return res; +} + +template <> +inline __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" + : "=r"(res) + : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { + __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { + // This is slightly faster + __device__ inline float operator()(float const& x, float const& y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { + __device__ inline T operator()(T const& x, T const& y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Allreduce<2> { + template + static __device__ inline T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + } +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting + // "int_tuple.hpp(74): error: conversion to inaccessible base class" + // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. +template +inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; + auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) + // TD [2023-08-13]: Same error as above on Cutlass 3.2 + // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + // get<0, 1>(l), + // get<1, 1, 1>(l)); + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), + get<1>(get<1>(get<1>(l)))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void relu_(Tensor& tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); +#pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +inline __device__ auto convert_type_relu(Tensor const& tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); +#pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = flash::convert_type(tensor); + flash::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy(TiledCopy tiled_copy, Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, int max_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor const& S0, + Tensor const& S1, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, + const int max_MN = 0, const int row_idx_switch = 0) { + CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); +// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); } +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); } +#pragma unroll + for (int m = 0; m < size<1>(S0); ++m) { + auto& S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1; + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S0); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_w_min_idx(Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc new file mode 100644 index 0000000000..f21dff08e0 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -0,0 +1,234 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cuda/bert/group_query_attention.h" +#include "contrib_ops/cuda/bert/group_query_attention_helper.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", {DataTypeImpl::GetTensorType()}) \ + .MayInplace(3, 1) \ + .MayInplace(4, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 6), \ + GroupQueryAttention); + +// REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) + : CudaKernel(info) { + int64_t num_heads = 0; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + is_unidirectional_ = true; + // left_padding_ = info.GetAttrOrDefault("left_padding_last_token", 0) == 1; + is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; + scale_ = info.GetAttrOrDefault("scale", 0.0f); + +#if USE_FLASH_ATTENTION + disable_flash_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); +#else + disable_flash_attention_ = true; +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + disable_memory_efficient_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); +#else + disable_memory_efficient_attention_ = true; +#endif +} + +template +Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* past_key = context->Input(3); + const Tensor* past_value = context->Input(4); + const Tensor* seqlens_k = context->Input(5); + const Tensor* total_seqlen = context->Input(6); + + auto& device_prop = GetDeviceProp(); + GroupQueryAttentionParameters parameters; + typedef typename ToCudaType::MappedType CudaT; + GroupQueryAttentionData data; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + ¶meters, + num_heads_, + kv_num_heads_, + seqlens_k, + total_seqlen, + is_past_bsnh_, + scale_, + device_prop.maxThreadsPerBlock)); + parameters.is_unidirectional = is_unidirectional_; + // parameters.left_padding = left_padding_; + int sequence_length = parameters.sequence_length; + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(parameters.batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(parameters.hidden_size); + Tensor* output = context->Output(0, output_shape); + +#if USE_FLASH_ATTENTION + bool use_flash_attention = !disable_flash_attention_ && + onnxruntime::flash::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.kv_num_heads); + // Allocate buffers + size_t softmax_lse_bytes = 0; + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + // softmax buffer + softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); + // split kv buffer + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); +#else + constexpr bool use_flash_attention = false; + auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + int sm = (device_prop.major * 10) + device_prop.minor; + bool use_memory_efficient_attention = + !use_flash_attention && + !disable_memory_efficient_attention_ && + (parameters.head_size & 7) == 0 && + parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && + (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + has_memory_efficient_attention(sm, sizeof(T) == 2); + // allocate buffers + size_t kv_buffer_bytes = 0; + // need a buffer if we must ungroup kv + const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads); + if (use_memory_efficient_attention && needs_buff) { + kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size); + } + size_t fmha_buffer_bytes = 0; + if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { + fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); + } + auto k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); +#else + constexpr bool use_memory_efficient_attention = false; + auto k_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream()); +#endif + + // seqlens_k buffer + size_t seqlens_k_bytes = 0; + seqlens_k_bytes = sizeof(int) * parameters.batch_size; + auto seqlens_k_buffer = GetScratchBuffer(seqlens_k_bytes, context->GetComputeStream()); + + std::vector present_dims; + if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + present_dims = { + parameters.batch_size, parameters.seqlen_present_kv_cache, parameters.kv_num_heads, parameters.head_size}; + } else { // BNSH + present_dims = { + parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, parameters.head_size}; + } + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(1, present_shape); + Tensor* present_value = context->Output(2, present_shape); + + data.query = reinterpret_cast(query->Data()); + data.key = reinterpret_cast(key->Data()); + data.value = reinterpret_cast(value->Data()); + data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); + data.output = reinterpret_cast(output->MutableData()); + data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); + data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); + data.seqlens_k = const_cast(seqlens_k->Data()); + data.use_flash_attention = use_flash_attention; + data.use_memory_efficient_attention = use_memory_efficient_attention; + if (data.past_key == data.present_key) { + parameters.kv_share_buffer = true; + } else { + parameters.kv_share_buffer = false; + } + // Flash Buffers + if (softmax_lse_buffer != nullptr) { + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + } + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } + if (seqlens_k_buffer != nullptr) { + data.seqlens_k_total = reinterpret_cast(seqlens_k_buffer.get()); + } + // Memory Efficient Buffers + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + + cublasHandle_t cublas = GetCublasHandle(context); + + return QkvToContext( + device_prop, cublas, context->GetComputeStream(), parameters, data); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h new file mode 100644 index 0000000000..aade0436dc --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/providers/cuda/cuda_kernel.h" +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class GroupQueryAttention final : public CudaKernel { + public: + GroupQueryAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + int kv_num_heads_; // different for k and v for group query attention + // bool left_padding_; // shifts last token to end of buffer + bool is_unidirectional_; // causal + bool is_past_bsnh_; + float scale_; + bool disable_flash_attention_; + bool disable_memory_efficient_attention_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h new file mode 100644 index 0000000000..2cb9955807 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -0,0 +1,228 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "contrib_ops/cpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace group_query_attention_helper { + +Status CheckInputs(const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* past_key, + const Tensor* past_value, + void* parameters, + int num_heads, + int kv_num_heads, + const Tensor* seqlens_k, + const Tensor* total_seqlen, + bool is_past_bsnh, + float scale) { + // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length + // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) + // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) + // no packing for q/k/v: + // query (Q) : (B, S, D) + // key (K) : (B, S, D_kv) + // value (V) : (B, S, D_kv) + ORT_UNUSED_PARAMETER(value); + + AttentionQkvFormat qkv_format = Q_K_V_BSNH; + AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; + + const auto& query_dims = query->Shape().GetDims(); + const auto& key_dims = key->Shape().GetDims(); + + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", + query_dims.size()); + } + + int batch_size = static_cast(query_dims[0]); + int sequence_length = static_cast(query_dims[1]); + int q_hidden_size = static_cast(query_dims[2]); + int head_size = static_cast(q_hidden_size) / num_heads; + + int kv_hidden_size = static_cast(key_dims[2]); + + int32_t past_sequence_length = 0; + if (past_key != nullptr && past_value != nullptr) { + const auto& past_key_dims = past_key->Shape().GetDims(); + const auto& past_value_dims = past_value->Shape().GetDims(); + + if (past_key_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' is expected to have 4 dimensions, got ", + past_key_dims.size()); + } + if (past_value_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' is expected to have 4 dimensions, got ", + past_value_dims.size()); + } + + if (past_key_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 0 should be batch_size, got ", + past_key_dims[0]); + } + if (past_value_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 0 should be batch_size, got ", + past_value_dims[0]); + } + + // BNSH + if (!is_past_bsnh) { + if (past_key_dims[2] != past_value_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence" + "length or past sequence length), got ", + past_key_dims[1]); + } + if (past_key_dims[1] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' shall have kv_num_heads"); + } + if (past_value_dims[1] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' shall have kv_num_heads"); + } + // We assume all sequence in past kv are right-padded to max or past sequence length + past_sequence_length = static_cast(past_key_dims[2]); + // BSNH + } else { + if (past_key_dims[1] != past_value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BNSH Input 'past_key' and 'past_value' should have same dimension 1 (max sequence" + "length or past sequence length), got ", + past_key_dims[1]); + } + if (past_key_dims[2] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' shall have kv_num_heads"); + } + if (past_value_dims[2] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' shall have kv_num_heads"); + } + // We assume all sequence in past kv are right-padded to max or past sequence length + past_sequence_length = static_cast(past_key_dims[1]); + } + + if (past_key_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 3 should be same as head_size, got ", + past_key_dims[3]); + } + if (past_value_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 3 should be same as head_size, got ", + past_value_dims[3]); + } + } else if (past_key != nullptr || past_value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' and 'past_value' shall be both present or both absent."); + } + + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } + if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } + + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } + + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } + + if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch_size)"); + } + + if (static_cast(sequence_length) != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)"); + } + + if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + + // Check seqlens_k tensor (holding past seqlen for token gen) + const auto& seqlens_dim = seqlens_k->Shape().GetDims(); + if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "seqlens_k must be shape (batch_size)."); + } + + // Set present sequence length and kv_share_buffer from input total_seqlen tensor + if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "total_sequence_length tensor must be of one element."); + } + int total_sequence_length = *((*total_seqlen).template Data()); + int present_sequence_length = std::max(total_sequence_length, past_sequence_length); + + bool is_prompt = sequence_length != 1; + + if (parameters != nullptr) { + GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; // sequence length of Q + output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors + output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors + output_parameters->hidden_size = q_hidden_size; + output_parameters->num_heads = num_heads; + output_parameters->head_size = q_hidden_size / num_heads; + output_parameters->kv_hidden_size = kv_hidden_size; + output_parameters->kv_num_heads = kv_num_heads; + output_parameters->is_unidirectional = true; + output_parameters->is_prompt = is_prompt; + output_parameters->scale = scale; + output_parameters->qkv_format = qkv_format; + output_parameters->past_kv_format = past_kv_format; + } + + return Status::OK(); +} + +Status CheckInputs(const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* past_key, + const Tensor* past_value, + void* parameters, + int num_heads, + int kv_num_heads, + const Tensor* seqlens_k, + const Tensor* total_seqlen, + bool is_past_bsnh, + float scale, + int max_threads_per_block) { + if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); + } + + return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); +} + +} // namespace group_query_attention_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu new file mode 100644 index 0000000000..2d158155ee --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -0,0 +1,773 @@ +/* + The implementation of this file is based on our Multi-Head Attention impl.cu file, + which is based on qkvToContext plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Modifications: +// (1) support GPT-2 past state, unidirectional mask (causal) +// (2) use flash attention kernel from (https://github.com/Dao-AILab/flash-attention) +// (3) support different number of heads for Q and KV +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/attention_softmax.h" +#include "contrib_ops/cuda/bert/transformer_common.h" +#include "contrib_ops/cuda/bert/add_bias_transpose.h" +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/cuda/bert/bert_padding.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +////////// Auxiliary Kernels for KV prep + +// Kernel for seqlens_k +__global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) { + int id = blockDim.x * blockIdx.x + threadIdx.x; + if (id < batch_size) seqlens_k[id] = seqlen; +} + +// Kernel to append new and past kv in either BSNH or BNSH format +// Adapted from ConcatTensorToTensor kernel in attention_kv_cache.cu file +template +__global__ void ConcatNewToPastKV(const int new_seqlen, + const int past_buffer_seqlen, + const T* past_kv, + const T* new_kv, + T* present_kv, + const int* seqlens_k, + const bool is_bsnh) { // refers to past; otherwise bnsh + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int present_buffer_seqlen = gridDim.x; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + const int present_batch_stride = present_buffer_seqlen * num_heads * H; + const int row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_buffer_seqlen * H; + + // past_kv: BPNH or BNPH + // new_kv: BLNH + // present_kv: BTNH or BNTH, where T = P + L + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; + if (s < past_seqlen) { + const int past_batch_stride = past_buffer_seqlen * num_heads * H; + const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; + const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; + present_kv[out_offset] = past_kv[in_offset]; + } else if (s < past_seqlen + new_seqlen) { + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; + present_kv[out_offset] = new_kv[in_offset]; + } +} + +// Use when (H*)*num_heads > 1024 +template +__global__ void ConcatNewToPastKVLarge(const int new_seqlen, + const int past_buffer_seqlen, + const int H, + const int num_heads, + const T* past_kv, + const T* new_kv, + T* present_kv, + const int* seqlens_k, + const bool is_bsnh) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int present_buffer_seqlen = gridDim.y; + + const int present_batch_stride = present_buffer_seqlen * num_heads * H; + const int row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_buffer_seqlen * H; + + // past_kv: BPNH or BNPH + // new_kv: BLNH + // present_kv: BTNH or BNTH, where T = P + L + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; + if (s < past_seqlen) { + const int past_batch_stride = past_buffer_seqlen * num_heads * H; + const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; + const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; + present_kv[out_offset] = past_kv[in_offset]; + } else if (s < past_seqlen + new_seqlen) { + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; + present_kv[out_offset] = new_kv[in_offset]; + } + } +} + +// Concat new to past in present. Supports past BSNH or past BNSH +template +Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.sequence_length; + const int past_sequence_length = parameters.seqlen_past_kv_cache; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); + + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; // divide by 4 so kernel can operate on 4 float16 elements at a time. + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(present_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatNewToPastKV<<>>(kv_sequence_length, + past_sequence_length, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKV<<>>(kv_sequence_length, + past_sequence_length, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = (H * kv_num_heads + 255) / 256; + const dim3 grid(steps, present_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + past_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + past_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel to append new kv to kv buffer in place +template +__global__ void ConcatKVInPlace(const int max_seqlen, + T* kv_buff, + const T* new_kv, + const int* seqlens_k, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int new_seqlen = gridDim.x; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + const int present_batch_stride = max_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : max_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + const int past_seq_len = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; +} + +template +__global__ void ConcatKVInPlaceLarge(const int max_seqlen, + const int H, + const int num_heads, + T* kv_buff, + const T* new_kv, + const int* seqlens_k, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int new_seqlen = gridDim.y; + + const int present_batch_stride = max_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : max_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + const int past_seq_len = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; + } +} + +// Concat new to kv buffer in place +template +Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + // Indicates past sequence_length of each sequence + const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); + + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(kv_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatKVInPlace<<>>(present_sequence_length, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlace<<>>(present_sequence_length, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = int(ceil(float(H * kv_num_heads) / 256.0)); + const dim3 grid(steps, kv_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatKVInPlaceLarge<<>>(present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlaceLarge<<>>(present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh +template +__global__ void Ungroup(const T* kv_in, + T* kv_out, + const int in_seqlen, + const int kv_num_heads, + const bool is_bsnh) { + const int h = threadIdx.x; + const int out_n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int out_seqlen = gridDim.x; + const int q_num_heads = blockDim.y; + const int H = blockDim.x; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + const int in_n = out_n / q_kv_head_ratio; + + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; +} + +template +__global__ void UngroupLarge(const T* kv_in, + T* kv_out, + const int H, + const int in_seqlen, + const int q_num_heads, + const int kv_num_heads, + const bool is_bsnh) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); // index along H * q_num_heads elements + if (i < H * q_num_heads) { + const int out_seqlen = gridDim.y; + const int s = blockIdx.y; + const int b = blockIdx.z; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + + const int h = i % H; + const int out_n = i / H; + const int in_n = out_n / q_kv_head_ratio; + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; + } +} + +// Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it. +Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, + float2* k_buff, float2* v_buff, + const float2* k_og, const float2* v_og, + const int buff_seqlen, const int og_seqlen, + const bool is_bsnh, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + const int H = head_size / 4; + if (H * num_heads <= max_threads_per_block) { + const dim3 grid(buff_seqlen, batch_size, 1); + const dim3 block(H, num_heads, 1); + Ungroup<<>>(k_og, + k_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + Ungroup<<>>(v_og, + v_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + } else { + int steps = int(ceil(float(H * num_heads) / 256.0)); + const dim3 grid(steps, buff_seqlen, batch_size); + const dim3 block(256, 1, 1); + UngroupLarge<<>>(k_og, + k_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + UngroupLarge<<>>(v_og, + v_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + } + return CUDA_CALL(cudaGetLastError()); +} + + +__global__ void PastToTotalSeqlen(int32_t* seqlens_k, + int32_t* seqlens_k_buff, + const int add_seqlen) { + seqlens_k_buff[threadIdx.x] = seqlens_k[threadIdx.x] + add_seqlen; +} + +// Convert Past to Total sequence length tensor +Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, + int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, + const int threads_per_block) { + if (parameters.is_prompt) { + return Status::OK(); + } + const int batch_size = parameters.batch_size; + const int add_seqlen = is_total ? parameters.sequence_length : 0; + + const dim3 grid(1, 1, 1); + // TODO(aciddelgado): unlikely but could have a bigger batch_size than max_threads + const dim3 block(batch_size, 1, 1); + + // TODO(aciddelgado): small version + PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, add_seqlen); + + return CUDA_CALL(cudaGetLastError()); +} + +// // Kernel to append new kv to kv buffer in place +// template +// __global__ void LeftPadLast(const int max_seqlen, +// T* kv_buff, +// const int* seqlens_k) { // refers to kv buff; otherwise bnsh +// const int h = threadIdx.x; +// const int n = blockIdx.x; +// const int b = blockIdx.y; + +// const int num_heads = gridDim.x; +// const int H = blockDim.x; + +// const int present_batch_stride = max_seqlen * num_heads * H; +// const int present_row_stride = num_heads * H; +// const int present_head_stride = H; + +// // kv_buff: BTNH or BNTH with buffered memory for new +// // new_kv: BLNH + +// const int s = seqlens_k[b]; + +// const int in_offset = b * present_batch_stride + s * present_row_stride + n * present_head_stride + h; +// const int out_offset = b * present_batch_stride + (max_seqlen - 1) * present_row_stride + n * present_head_stride + h; +// kv_buff[out_offset] = kv_buff[in_offset]; +// } + +// // Concat new to kv buffer in place +// template +// Status LaunchLeftPadLast(contrib::GroupQueryAttentionParameters& parameters, +// GroupQueryAttentionData& data, +// cudaStream_t stream, +// const int max_threads_per_block) { +// const int batch_size = parameters.batch_size; +// const int sequence_length = parameters.sequence_length; +// const int num_heads = parameters.num_heads; +// const int head_size = parameters.head_size; + +// // Indicates past sequence_length of each sequence +// const int* seqlens_k = reinterpret_cast(data.seqlens_k); + +// const int H = head_size / 4; +// const dim3 grid(num_heads, batch_size, 1); +// const dim3 block(H, 1, 1); +// LeftPadLast<<>>(sequence_length, +// reinterpret_cast(data.output), +// seqlens_k); +// return CUDA_CALL(cudaGetLastError()); +// } + +////////// Launch Kernels + +#if USE_FLASH_ATTENTION +template +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + void* query = reinterpret_cast(const_cast(data.query)); + void* key = reinterpret_cast(const_cast(data.key)); + void* value = reinterpret_cast(const_cast(data.value)); + + bool is_causal = parameters.is_unidirectional; + + // Note: seqlens_k is past sequence length for flash + if (parameters.is_prompt) { + // Launch kernel to copy seqlen + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk -1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); + } + + void* seqlens_k = reinterpret_cast(data.seqlens_k); + + if (parameters.kv_share_buffer) { + // Share buffer case + if (data.past_key == nullptr || data.past_key != data.present_key) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past and present kv shall share the same tensor when kv_share_buffer is on."); + } + + if (parameters.is_prompt) { + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + key = nullptr; + value = nullptr; + seqlens_k = reinterpret_cast(data.seqlens_k_total); + } + + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), + seqlens_k, batch_size, num_heads, kv_num_heads, + head_size, sequence_length, present_sequence_length, kv_sequence_length, + scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum))); + } else { + // Not share buffer case + // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient + if (data.past_key != nullptr && data.past_key == data.present_key) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past and present kv share the same tensor but kv_share_buffer is not on."); + } + + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + + if (!parameters.is_prompt) { + ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + } + + seqlens_k = reinterpret_cast(data.seqlens_k_total); + + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); + DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size); + DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size); + DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast(data.softmax_lse), + seqlens_k, batch_size, num_heads, kv_num_heads, + head_size, sequence_length, present_sequence_length, 0, + scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum))); + } + + // if (parameters.left_padding && parameters.is_prompt) { + // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); + // } + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION +template +Status EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + const void* query = reinterpret_cast(data.query); + const void* key = reinterpret_cast(data.key); + const void* value = reinterpret_cast(data.value); + + if (parameters.is_prompt) { + // Launch kernel to copy seqlen + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); + } else { + ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + } + + if (parameters.kv_share_buffer) { + // Share buffer case + if (data.past_key == nullptr || data.past_key != data.present_key) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past and present kv shall share the same tensor when kv_share_buffer is on."); + } + // Concatenate new kv in place + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + } else { + // Not share buffer case + if (data.past_key != nullptr && data.past_key == data.present_key) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past and present kv share the same tensor but kv_share_buffer is not on."); + } + // Copy past and concat new KV to present buffer + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + } + + // Ungroup if grouped, otherwise use present kv directly + const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + if (num_heads == kv_num_heads) { + // Use present kv directly if not grouped + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, present_sequence_length, + present_sequence_length, is_bsnh, stream, max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("seqlens_k", data.seqlens_k_total, batch_size, 1); + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = batch_size; + p.num_heads = num_heads; + p.sequence_length = sequence_length; + p.kv_sequence_length = present_sequence_length; // TOTALLY UNNECESSARY IF WE HAVE SEQLENS_K, maybe remove + p.max_sequence_length = present_sequence_length; + p.qk_head_size = head_size; + p.v_head_size = head_size; + p.causal = parameters.is_unidirectional; + p.scale = scale; + p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr; + p.is_attn_bias_batched = false; + p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) + ? data.fmha_buffer + : nullptr; + p.stream = stream; + p.has_custom_right_padding = true; + run_memory_efficient_attention(p); + + // if (parameters.left_padding && parameters.is_prompt) { + // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); + // } + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +////////// API Functions + +template +Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; + +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); + } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, stream, parameters, data, scale); + } +#endif + + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet."); +} + +template struct GroupQueryAttentionData; + +template Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h new file mode 100644 index 0000000000..de32d7ea93 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/framework/allocator.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +struct GroupQueryAttentionData { + // Input Tensors + const T* query = nullptr; + const T* key = nullptr; + const T* value = nullptr; + const T* past_key = nullptr; + const T* past_value = nullptr; + int* seqlens_k = nullptr; + // Flash buffers + T* softmax_lse = nullptr; + T* softmax_lse_accum = nullptr; + T* out_accum = nullptr; + int* seqlens_k_total = nullptr; + // Memory Efficient buffers + T* fmha_buffer = nullptr; + T* k = nullptr; + T* v = nullptr; + // Output Tensors + T* output = nullptr; + T* present_key = nullptr; + T* present_value = nullptr; + // Kernel Flags + bool use_flash_attention = false; + bool use_memory_efficient_attention = false; +}; + +template +Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh index 5c083d64ee..ff3178b56c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh @@ -147,14 +147,16 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair< __shared__ T rsigma; // 1 / std.dev. T beta_v[ILP], gamma_v[ILP], output_v[ILP]; - if (beta != nullptr) { - VecT* beta_val = reinterpret_cast(&beta_v); - *beta_val = *reinterpret_cast(&beta[threadIdx.x * ILP]); - } - VecT* gamma_val = reinterpret_cast(&gamma_v); - *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + const bool is_valid = ILP * threadIdx.x < ld; + if (is_valid) { + if (beta != nullptr) { + VecT* beta_val = reinterpret_cast(&beta_v); + *beta_val = *reinterpret_cast(&beta[threadIdx.x * ILP]); + } - VecT* output_val = reinterpret_cast(&output_v); + VecT* gamma_val = reinterpret_cast(&gamma_v); + *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + } KeyValuePairSum pair_sum; const cub::KeyValuePair sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); @@ -165,13 +167,15 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair< } __syncthreads(); - if (ILP * threadIdx.x < ld) { + if (is_valid) { #pragma unroll for (int i = 0; i < ILP; i++) { output_v[i] = (beta != nullptr) ? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i] : gamma_v[i] * (input_v[i] - mu) * rsigma; } + + VecT* output_val = reinterpret_cast(&output_v); *(reinterpret_cast(&output[idx])) = *output_val; } } @@ -186,12 +190,15 @@ __device__ inline void SimplifiedLayerNormSmall(const T* input_v, const T& threa using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ T rsigma; // 1 / std.dev. - T gamma_v[ILP], output_v[ILP]; - VecT* gamma_val = reinterpret_cast(&gamma_v); - *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + const bool is_valid = ILP * threadIdx.x < ld; - VecT* output_val = reinterpret_cast(&output_v); + T gamma_v[ILP], output_v[ILP]; + + if (is_valid) { + VecT* gamma_val = reinterpret_cast(&gamma_v); + *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + } const T sum = BlockReduce(temp_storage).Sum(thread_data); @@ -200,11 +207,13 @@ __device__ inline void SimplifiedLayerNormSmall(const T* input_v, const T& threa } __syncthreads(); - if (ILP * threadIdx.x < ld) { + if (is_valid) { #pragma unroll for (int i = 0; i < ILP; i++) { output_v[i] = gamma_v[i] * input_v[i] * rsigma; } + + VecT* output_val = reinterpret_cast(&output_v); *(reinterpret_cast(&output[idx])) = *output_val; } } diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu index de3c3fb6ca..f002394600 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu @@ -924,55 +924,55 @@ Status LongformerQkvToContext( if (disable_compact_memory) { ORT_RETURN_IF_ERROR(LaunchLongformerSoftmaxSimpleKernel( - stream, - cublas, - workspace, - q, - k, - v, - attention_mask, - global_q, - global_k, - global_v, - global_attention, - global_index, - batch_global_num, - pinned_buffer, - temp_output, - rsqrt_head_size, - batch_size, - sequence_length, - num_heads, - head_size, - window, - element_size)); + stream, + cublas, + workspace, + q, + k, + v, + attention_mask, + global_q, + global_k, + global_v, + global_attention, + global_index, + batch_global_num, + pinned_buffer, + temp_output, + rsqrt_head_size, + batch_size, + sequence_length, + num_heads, + head_size, + window, + element_size)); } else { ORT_ENFORCE(max_num_global <= window); ORT_RETURN_IF_ERROR(LaunchLongformerSoftmaxKernel( - stream, - cublas, - workspace, - q, - k, - v, - attention_mask, - max_num_global, - compact_global_q, - global_q, - global_k, - global_v, - global_attention, - global_index, - batch_global_num, - pinned_buffer, - temp_output, - rsqrt_head_size, - batch_size, - sequence_length, - num_heads, - head_size, - window, - element_size)); + stream, + cublas, + workspace, + q, + k, + v, + attention_mask, + max_num_global, + compact_global_q, + global_q, + global_k, + global_v, + global_attention, + global_index, + batch_global_num, + pinned_buffer, + temp_output, + rsqrt_head_size, + batch_size, + sequence_length, + num_heads, + head_size, + window, + element_size)); } // The temp_output is BxNxSxH, transpose it to final output BxSxNxH diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 15f0bc1a74..5f33b26cc4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -7,6 +7,7 @@ #include "contrib_ops/cuda/bert/multihead_attention.h" #include "contrib_ops/cpu/bert/multihead_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -51,6 +52,17 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); #if USE_FLASH_ATTENTION + disable_flash_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); + min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( + attention::kMinSeqLenForFlashAttentionPackedQKV, + attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); +#else + disable_flash_attention_ = true; + min_seq_len_for_flash_attention_packed_qkv_ = 0; +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); #else disable_memory_efficient_attention_ = true; @@ -118,9 +130,51 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { int sm = device_prop.major * 10 + device_prop.minor; bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - bool is_mask_1d_key_seq_len_start = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; - bool use_fused_cross_attention = !disable_fused_cross_attention_ && + const bool pass_key_value_as_past = (parameters.pass_past_in_kv && nullptr != key && nullptr != value); + +#if USE_FLASH_ATTENTION || USE_MEMORY_EFFICIENT_ATTENTION + // Exclude this case since PrepareQkv will convert the format to BNSH. + bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; +#endif + +#if USE_FLASH_ATTENTION + bool use_flash_attention = !disable_flash_attention_ && + !past_no_bias && + nullptr == relative_position_bias && + nullptr == key_padding_mask && + parameters.head_size == parameters.v_head_size && + onnxruntime::flash::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.num_heads); + // When input is packed QKV format, TensorRT kernel might be faster than flash attention when sequence length <= 512. + if (use_flash_attention && key == nullptr && value == nullptr && + parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + use_flash_attention = false; + } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); +#else + constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr +#endif + + bool use_fused_cross_attention = !use_flash_attention && + !disable_fused_cross_attention_ && nullptr == key_padding_mask && nullptr == relative_position_bias && (nullptr == past_key && nullptr == past_value && !parameters.pass_past_in_kv) && @@ -141,7 +195,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } } - bool use_fused_runner = !disable_fused_self_attention_ && + bool use_fused_runner = !use_flash_attention && + !disable_fused_self_attention_ && fused_cross_attention_kernel == nullptr && nullptr == relative_position_bias && (value != nullptr || key == nullptr) && @@ -166,32 +221,30 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } } - const bool pass_key_value_as_past = (parameters.pass_past_in_kv && nullptr != key && nullptr != value); - -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16 - parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 || - parameters.kv_sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32; - - // Exclude this case since PrepareQkv will convert the format to BNSH. - bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; + parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32 || + parameters.kv_sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32; bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; - bool use_memory_efficient_attention = fused_runner == nullptr && + bool use_memory_efficient_attention = !use_flash_attention && + fused_runner == nullptr && fused_cross_attention_kernel == nullptr && !disable_memory_efficient_attention_ && + (parameters.head_size & 7) == 0 && + (parameters.v_head_size & 7) == 0 && is_long_sequence && !past_no_bias && (relative_position_bias == nullptr || is_good_for_rpb) && - (nullptr == key_padding_mask || is_mask_1d_key_seq_len_start) && + (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && has_memory_efficient_attention(sm, sizeof(T) == 2); #else constexpr bool use_memory_efficient_attention = false; - ORT_UNUSED_PARAMETER(is_mask_1d_key_seq_len_start); #endif // When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace. + // TODO(tianleiwu): flash attention or memory efficient attention might not need qkv workspace sometime. bool no_qkv_workspace = nullptr == value && (use_fused_cross_attention || (nullptr != fused_runner && nullptr == key)) && nullptr == key_padding_mask && @@ -211,6 +264,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.kv_sequence_length, parameters.total_sequence_length, fused_runner, + use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention); } @@ -219,19 +273,18 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const size_t past_k_bytes = element_size * parameters.batch_size * parameters.kv_sequence_length * parameters.num_heads * parameters.head_size; const size_t past_v_bytes = element_size * parameters.batch_size * parameters.kv_sequence_length * parameters.num_heads * parameters.v_head_size; - auto temp_k_work_space = (parameters.pass_past_in_kv || use_memory_efficient_attention) ? GetScratchBuffer(past_k_bytes, context->GetComputeStream()) : nullptr; - auto temp_v_work_space = (parameters.pass_past_in_kv || use_memory_efficient_attention) ? GetScratchBuffer(past_v_bytes, context->GetComputeStream()) : nullptr; + const bool use_temp_k_v_workspace = parameters.pass_past_in_kv || use_memory_efficient_attention || use_flash_attention; + auto temp_k_work_space = use_temp_k_v_workspace ? GetScratchBuffer(past_k_bytes, context->GetComputeStream()) : nullptr; + auto temp_v_work_space = use_temp_k_v_workspace ? GetScratchBuffer(past_v_bytes, context->GetComputeStream()) : nullptr; typedef typename ToCudaType::MappedType CudaT; AttentionData data; - data.gemm_buffer = nullptr; data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); data.key = (nullptr == key || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(key->Data()); data.value = (nullptr == value || parameters.pass_past_in_kv) ? nullptr : reinterpret_cast(value->Data()); data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); - data.past = nullptr; data.past_key = pass_key_value_as_past ? reinterpret_cast(key->Data()) : (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); @@ -241,17 +294,23 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); data.has_qkv_workspace = !no_qkv_workspace; data.workspace = reinterpret_cast(work_space.get()); - data.temp_k_workspace = (parameters.pass_past_in_kv || use_memory_efficient_attention) ? reinterpret_cast(temp_k_work_space.get()) : nullptr; - data.temp_v_workspace = (parameters.pass_past_in_kv || use_memory_efficient_attention) ? reinterpret_cast(temp_v_work_space.get()) : nullptr; + data.temp_k_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_k_work_space.get()) : nullptr; + data.temp_v_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_v_work_space.get()) : nullptr; data.output = reinterpret_cast(output->MutableData()); - data.present = nullptr; data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); data.fused_runner = reinterpret_cast(fused_runner); data.fused_cross_attention_kernel = fused_cross_attention_kernel; + data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index af5045e70d..33fa3d50e4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -28,7 +28,9 @@ class MultiHeadAttention final : public CudaKernel { bool disable_fused_self_attention_; bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; + bool disable_flash_attention_; bool disable_memory_efficient_attention_; + int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index 1b2c5f6200..ec8b1d051b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -283,7 +283,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { MHARunner* fused_runner = this->GetFusedRunner(device_prop, parameters); bool use_memory_efficient_attention = false; -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION if (nullptr == fused_runner) { int sm = device_prop.major * 10 + device_prop.minor; bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0; @@ -324,6 +324,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { parameters.v_head_size, parameters.sequence_length, fused_runner, + false, use_memory_efficient_attention, no_qkv_workspace); auto work_space = this->GetScratchBuffer(workSpaceSize, context->GetComputeStream()); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index 5a99a98ce8..3b52320839 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -16,6 +16,7 @@ #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/rotary_embedding_util.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" using namespace onnxruntime::cuda; using namespace onnxruntime::contrib::attention_softmax_cuda; @@ -47,22 +48,32 @@ size_t GetAttentionWorkspaceSize( size_t v_head_size, size_t sequence_length, void* fused_runner, + bool use_flash_attention, bool use_memory_efficient_attention, bool no_qkv_workspace) { // Note that q, k and v might need alignment for fused attention kernels. const size_t qkv_bytes = no_qkv_workspace ? 0 : (element_size * batch_size * num_heads * sequence_length * (qk_head_size + qk_head_size + v_head_size)); +#if USE_FLASH_ATTENTION + // Use portion of workspace for softmax buffer. + if (use_flash_attention) { + size_t flash_buffer_bytes = onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads); + return qkv_bytes + flash_buffer_bytes; + } +#else + ORT_UNUSED_PARAMETER(use_flash_attention); +#endif + if (fused_runner != nullptr) { return qkv_bytes; } -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION if (use_memory_efficient_attention) { size_t fmha_buffer_bytes = 0; if (MemoryEfficientAttentionParams::need_workspace(v_head_size, element_size == sizeof(float))) { fmha_buffer_bytes = batch_size * sequence_length * num_heads * v_head_size * sizeof(float); } - return qkv_bytes + fmha_buffer_bytes; } #else @@ -455,7 +466,7 @@ Status FusedScaledDotProductAttention( return Status::OK(); } -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION template Status FusedScaledDotProductAttentionCutlass( const cudaDeviceProp& device_prop, @@ -496,10 +507,12 @@ Status FusedScaledDotProductAttentionCutlass( MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_half = sizeof(T) == 2; + p.is_kv_bsnh = true; p.batch_size = parameters.batch_size; p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; @@ -516,6 +529,7 @@ Status FusedScaledDotProductAttentionCutlass( p.output = data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? accum_workspace : nullptr; p.stream = stream; + p.has_custom_right_padding = false; run_memory_efficient_attention(p); DUMP_TENSOR("PackedAttention cutlass output", data.output, parameters.token_count, num_heads, v_head_size); @@ -635,7 +649,7 @@ Status QkvToContext( return FusedScaledDotProductAttention(device_prop, stream, parameters, data); } -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION if (data.use_memory_efficient_attention) { return FusedScaledDotProductAttentionCutlass(device_prop, stream, parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h index 9476bbed26..629ca59c73 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h @@ -25,6 +25,7 @@ size_t GetAttentionWorkspaceSize( size_t v_head_size, size_t sequence_length, void* fused_runner, + bool use_flash_attention, bool use_memory_efficient_attention, bool no_qkv_workspace); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 8ffae86ae5..1b026e6477 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -9,6 +9,7 @@ #include "contrib_ops/cuda/bert/packed_multihead_attention_impl.h" #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -42,6 +43,17 @@ PackedMultiHeadAttention::PackedMultiHeadAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION + disable_flash_attention_ = sizeof(T) != 2 || onnxruntime::ParseEnvironmentVariableWithDefault( + attention::kDisableFlashAttention, false); + min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( + attention::kMinSeqLenForFlashAttentionPackedQKV, + attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); +#else + disable_flash_attention_ = true; + min_seq_len_for_flash_attention_packed_qkv_ = 0; +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION disable_memory_efficient_attention_ = onnxruntime::ParseEnvironmentVariableWithDefault( attention::kDisableMemoryEfficientAttention, false); #else @@ -94,8 +106,9 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, int64_t v_hidden_size = hidden_size; if (query_dims.size() == 4) { if (key != nullptr || value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' is expected to be empty when 'query' has 4 dimensions in packing mode"); + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' is expected to be empty when 'query' has 4 dimensions in packing mode"); } } else { // query_dims.size() == 2 if (key == nullptr) { @@ -143,11 +156,12 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, const auto& cu_seq_len_dims = cu_seq_len_shape.GetDims(); if (cu_seq_len_dims.size() != 1 || cu_seq_len_dims[0] != batch_size + 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'cumulative_sequence_length' should have 1 dimension with size equal to batch_size + 1"); + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'cumulative_sequence_length' should have 1 dimension with size equal to batch_size + 1"); } - // TODO(tianleiwu): move relative position bias shape checker to a helper function. It is shared by multiple operators. + // TODO(tianleiwu): move relative position bias shape checker to a helper function. It is shared by multiple ops. const int num_heads = this->GetNumHeads(); bool broadcast_res_pos_bias = false; if (relative_position_bias != nullptr) { @@ -227,19 +241,39 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co Tensor* output = context->Output(0, output_shape); auto& device_prop = this->GetDeviceProp(); - MHARunner* fused_runner = this->GetFusedRunner(device_prop, parameters); + + bool use_flash_attention = false; +#if USE_FLASH_ATTENTION + if (!disable_flash_attention_) { + use_flash_attention = !parameters.has_relative_position_bias && + parameters.head_size == parameters.v_head_size && + onnxruntime::flash::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.num_heads); + + // When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512. + if (use_flash_attention && key == nullptr && value == nullptr && + parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + use_flash_attention = false; + } + } +#endif + + MHARunner* fused_runner = use_flash_attention ? nullptr : this->GetFusedRunner(device_prop, parameters); bool use_memory_efficient_attention = false; -#if USE_FLASH_ATTENTION - if (nullptr == fused_runner && !disable_memory_efficient_attention_) { +#if USE_MEMORY_EFFICIENT_ATTENTION + if (!use_flash_attention && nullptr == fused_runner && !disable_memory_efficient_attention_) { int sm = device_prop.major * 10 + device_prop.minor; bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0; - use_memory_efficient_attention = is_good_for_rpb && - (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) && - (parameters.head_size & 7) == 0 && - (parameters.v_head_size & 7) == 0 && - has_memory_efficient_attention(sm, sizeof(T) == 2); + use_memory_efficient_attention = + is_good_for_rpb && + (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + (parameters.head_size & 7) == 0 && + (parameters.v_head_size & 7) == 0 && + has_memory_efficient_attention(sm, sizeof(T) == 2); } #endif @@ -250,7 +284,9 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co constexpr size_t element_size = sizeof(T); // When the source and target format is same (like TN3H => TN3H, or TNH => TNH) and no bias, need not transpose qkv. const bool no_qkv_workspace = (fused_runner != nullptr && key == nullptr && bias == nullptr) || - (use_memory_efficient_attention && value != nullptr && bias == nullptr); + ((use_memory_efficient_attention || use_flash_attention) && + value != nullptr && + bias == nullptr); size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, parameters.batch_size, parameters.num_heads, @@ -258,6 +294,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co parameters.v_head_size, parameters.sequence_length, fused_runner, + use_flash_attention, use_memory_efficient_attention, no_qkv_workspace); auto work_space = this->GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -268,12 +305,15 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co data.key = (key == nullptr) ? nullptr : reinterpret_cast(key->Data()); data.value = (value == nullptr) ? nullptr : reinterpret_cast(value->Data()); data.bias = (bias == nullptr) ? nullptr : reinterpret_cast(bias->Data()); - data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); + data.relative_position_bias = (nullptr == relative_position_bias) + ? nullptr + : reinterpret_cast(relative_position_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.token_offset = token_offset->Data(); data.cumulative_sequence_length = cumulative_sequence_length->Data(); data.output = reinterpret_cast(output->MutableData()); data.fused_runner = reinterpret_cast(fused_runner); + data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; data.no_qkv_workspace = no_qkv_workspace; data.source_qkv_format = (key == nullptr) ? AttentionQkvFormat::QKV_TN3H : AttentionQkvFormat::Q_K_V_TNH; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h index b59463a776..e30c603dc3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h @@ -31,6 +31,8 @@ class PackedMultiHeadAttention final : public TrtFusedAttention, public CudaK float scale_; // the scale for softmax in memory efficient attention or unfused attention. bool disable_memory_efficient_attention_; + bool disable_flash_attention_; + int min_seq_len_for_flash_attention_packed_qkv_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index d27cf975cb..8a508241d8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -17,6 +17,7 @@ #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/rotary_embedding_util.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" using namespace onnxruntime::cuda; using namespace onnxruntime::contrib::attention_softmax_cuda; @@ -32,7 +33,6 @@ static constexpr int32_t kMAX_THREADS_PER_BLOCK = 256; #define ADD_BIAS(value, bias_value) (biases == nullptr) ? value : (value + bias_value) #define GET_BIAS(bias_value) (biases == nullptr) ? T{} : bias_value - // Grid: (S, B) // Block: 256 // For unfused PackedMultiHeadAttention @@ -208,7 +208,6 @@ __global__ void TransposeQKV_TNH_TN3H( } } - // Grid: (S, B) // Block: 256 // For unfused PackedMultiHeadAttention @@ -329,7 +328,6 @@ __global__ void TransposeQKV_TN3H_3TNH( } } - // Grid: (T) // Block: 256 // For TRT fused attention. @@ -378,7 +376,6 @@ __global__ void AddBias_TN3H_TN3H( } } - template void InvokeTranspose( const T* query, const T* key, const T* value, const T* bias, T* output, @@ -587,6 +584,77 @@ Status FusedAttentionTrt( #if USE_FLASH_ATTENTION template +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + PackedAttentionParameters& parameters, + PackedMultiHeadAttentionData& data) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // Q, K and V pointers + const int model_dimension_qk = num_heads * qk_head_size; + const int model_dimension_v = num_heads * v_head_size; + const size_t elements_qk = static_cast(parameters.token_count) * static_cast(model_dimension_qk); + const size_t elements_v = static_cast(parameters.token_count) * static_cast(model_dimension_v); + + // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. Otherwise, transpose BSN3H to 3BSNH + if (!data.no_qkv_workspace) { + LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, + data.token_offset, parameters.token_count, stream); + } + + float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) + : parameters.scale; + int32_t* cu_seqlens_q = const_cast(data.cumulative_sequence_length); + int32_t* cu_seqlens_k = const_cast(data.cumulative_sequence_length); + const void* query = data.no_qkv_workspace ? data.query : data.workspace; + const void* key = data.no_qkv_workspace ? data.key : (data.workspace + elements_qk); + const void* value = data.no_qkv_workspace ? data.value : (data.workspace + elements_qk + elements_qk); + void* softmax_lse_buffer = data.no_qkv_workspace + ? data.workspace + : (data.workspace + elements_qk + elements_qk + elements_v); + + ORT_RETURN_IF_ERROR( + onnxruntime::flash::mha_varlen_fwd( + device_prop, + stream, + const_cast(query), + const_cast(key), + const_cast(value), + data.output, + cu_seqlens_q, + cu_seqlens_k, + softmax_lse_buffer, + batch_size, + num_heads, + num_heads, // num_heads_k + qk_head_size, + sequence_length, + sequence_length, + scale, + false // is causal + )); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), parameters.token_count, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", reinterpret_cast(key), parameters.token_count, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", reinterpret_cast(value), parameters.token_count, num_heads, v_head_size); + DUMP_TENSOR_D("cumulative_sequence_length", data.cumulative_sequence_length, 1, batch_size + 1); + DUMP_TENSOR("PackedMHA flash output", data.output, parameters.token_count, num_heads, v_head_size); + + return Status::OK(); +} +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION +template Status FusedAttentionCutlass( const cudaDeviceProp& device_prop, cudaStream_t stream, @@ -620,6 +688,7 @@ Status FusedAttentionCutlass( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; @@ -634,17 +703,19 @@ Status FusedAttentionCutlass( p.attn_bias = data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v))) : nullptr; p.stream = stream; + p.has_custom_right_padding = false; run_memory_efficient_attention(p); DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("PackedMHA cutlass q(BSNH)", reinterpret_cast(p.query), parameters.token_count, num_heads * qk_head_size); - DUMP_TENSOR_D("PackedMHA cutlass k(BSNH)", reinterpret_cast(p.key), parameters.token_count, num_heads * qk_head_size); - DUMP_TENSOR_D("PackedMHA cutlass v(BSNH)", reinterpret_cast(p.value), parameters.token_count, num_heads * v_head_size); - DUMP_TENSOR_D("PackedMHA cutlass cumulative_sequence_length", data.cumulative_sequence_length, 1, batch_size + 1); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(p.query), parameters.token_count, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", reinterpret_cast(p.key), parameters.token_count, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", reinterpret_cast(p.value), parameters.token_count, num_heads, v_head_size); + DUMP_TENSOR_D("cumulative_sequence_length", data.cumulative_sequence_length, 1, batch_size + 1); DUMP_TENSOR("PackedMHA cutlass output", data.output, parameters.token_count, num_heads, v_head_size); return Status::OK(); @@ -707,10 +778,10 @@ Status UnfusedAttention( // Q, K and V are ready now DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("PackedMHA unfused q (BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("PackedMHA unfused k (BNSH)", k, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("PackedMHA unfused v (BNSH)", v, batch_size, num_heads, sequence_length, v_head_size); - DUMP_TENSOR_D("PackedMHA unfused QK", scaled_qk, batch_size * num_heads, sequence_length, sequence_length); + DUMP_TENSOR_D("q (BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k (BNSH)", k, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("v (BNSH)", v, batch_size, num_heads, sequence_length, v_head_size); + DUMP_TENSOR_D("QK", scaled_qk, batch_size, num_heads, sequence_length, sequence_length); const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length); @@ -727,7 +798,7 @@ Status UnfusedAttention( num_heads, attention_score, stream)); - DUMP_TENSOR_D("PackedMHA unfused Softmax", attention_score, batch_size * num_heads, sequence_length, sequence_length); + DUMP_TENSOR_D("Softmax", attention_score, batch_size, num_heads, sequence_length, sequence_length); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v T* temp_output = qkv; @@ -762,6 +833,12 @@ Status QkvToContext( } #if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data); + } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION if (data.use_memory_efficient_attention) { return FusedAttentionCutlass(device_prop, stream, parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h index c7b7280878..eeca72f16e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h @@ -29,6 +29,7 @@ struct PackedMultiHeadAttentionData { void* fused_runner; + bool use_flash_attention; bool use_memory_efficient_attention; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc new file mode 100644 index 0000000000..b4b5dac1fb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" +#include "contrib_ops/cuda/bert/rotary_embedding.h" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::rotary_embedding_helper; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RotaryEmbedding, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + RotaryEmbedding); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { + scale = info.GetAttrOrDefault("scale", 1.0); + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); +} + +template +Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* position_ids = context->Input(1); + const Tensor* cos_cache = context->Input(2); + const Tensor* sin_cache = context->Input(3); + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input, + position_ids, + cos_cache, + sin_cache, + ¶meters)); + + Tensor* output = context->Output(0, input->Shape()); + + if (parameters.sequence_length > parameters.max_sequence_length) { + // Launch update_cos_sin_cache kernel with scale + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + // Launch rotary embedding kernel + typedef typename ToCudaType::MappedType CudaT; + auto& device_prop = GetDeviceProp(); + return LaunchRotaryEmbeddingKernel( + Stream(context), + reinterpret_cast(output->template MutableData()), + reinterpret_cast(input->template Data()), + position_ids->Data(), + reinterpret_cast(cos_cache->template Data()), + reinterpret_cast(sin_cache->template Data()), + parameters.batch_size, + parameters.sequence_length, + parameters.num_heads, + parameters.head_size, + parameters.max_sequence_length, + parameters.position_ids_format, + interleaved, + device_prop.maxThreadsPerBlock); + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h new file mode 100644 index 0000000000..6dab2ad567 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class RotaryEmbedding final : public CudaKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + float scale; + bool interleaved; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu new file mode 100644 index 0000000000..c54e72dcfc --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -0,0 +1,141 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +/* +Kernel implementation for rotary embeddings. +*/ + +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH + const T* input, // BxSxNxH + const T* cos_cache, // Mx(H/2) + const T* sin_cache, // Mx(H/2) + const int64_t* position_ids, // (1) or BxS + const int sequence_length, + const int num_heads, + const int head_size, + const int position_ids_format, + const bool interleaved) { + // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length + // Use .x in innermost loop to access global memory efficiently + + const int b = blockIdx.z; + const int s = blockIdx.y; + const int n = blockIdx.x; + + const int i = threadIdx.x; + + const int block_offset = b * sequence_length * num_heads + s * num_heads + n; + const int data_offset = block_offset * head_size; + + const T* input_data = input + data_offset; + T* output_data = output + data_offset; + + // Cache is (M, H/2) + const int half_head_size = head_size / 2; + const int position_id = (position_ids_format == 0) ? \ + static_cast(position_ids[0]) + s \ + : static_cast(position_ids[b * sequence_length + s]); + const int cache_offset = position_id * half_head_size; + const T* cos_data = cos_cache + cache_offset; + const T* sin_data = sin_cache + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + if (interleaved) { + cache_idx = (i / 2) % half_head_size; + sign = (i % 2 == 0) ? -1 : 1; + j = (i % 2 == 0) ? i+1 : i-1; // i - sign + } else { + cache_idx = i % half_head_size; + sign = (i < half_head_size) ? -1 : 1; + j = (i + half_head_size) % head_size; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; +} + + +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block) { + + constexpr int smem_size = 0; + const dim3 grid(num_heads, sequence_length, batch_size); + const dim3 block(head_size, 1, 1); + + // Note: Current implementation assumes head_size <= max_threads_per_block + // because head_size is currently large for LLaMA-2. For smaller head_size + // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` + // instead. This will require kernel changes to support. + + assert(head_size <= max_threads_per_block); + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, + sequence_length, num_heads, head_size, position_ids_format, interleaved + ); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + float* output, + const float* input, + const int64_t* position_ids, + const float* cos_cache, + const float* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block); + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + half* output, + const half* input, + const int64_t* position_ids, + const half* cos_cache, + const half* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block); + + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h new file mode 100644 index 0000000000..29ff48a8ad --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 78174181ac..3299bc2cb1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -3,6 +3,7 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/nn/layer_norm_impl.h" +#include "core/common/narrow.h" #include "skip_layer_norm.h" #include "skip_layer_norm_impl.h" #include "contrib_ops/cpu/skip_layer_norm_helper.h" @@ -50,6 +51,11 @@ template Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { const Tensor* input = ctx->Input(0); const Tensor* skip = ctx->Input(1); + if (strict_ && skip->Shape() != input->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "'input' and 'skip' shall have same shape when enable_skip_layer_norm_strict_mode is True"); + } + const Tensor* gamma = ctx->Input(2); const Tensor* beta = Simplified ? nullptr : ctx->Input(3); @@ -57,16 +63,13 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const Tensor* output = ctx->Output(0, input->Shape()); - // For inferencing, we support one more optional output which is the sum - // of the input and skip tensors - Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); + // Optional output for the sum of skip, input and bias tensors (It is also the input of Layer Normalization). + Tensor* sum_output = ctx->Output(3, input->Shape()); const auto& input_dims = input->Shape().GetDims(); size_t input_dims_size = input_dims.size(); - const auto& skip_dims = skip->Shape().GetDims(); - size_t skip_dims_size = skip_dims.size(); - int hidden_size = static_cast(input_dims[input_dims_size - 1]); + int hidden_size = onnxruntime::narrow(input_dims[input_dims_size - 1]); ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input, skip, @@ -76,12 +79,15 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const hidden_size, input_dims_size)); - const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false; - const int skip_size = static_cast(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]); + int row_count = onnxruntime::narrow(input->Shape().SizeToDimension(input_dims_size - 1)); + if (row_count == 0) { + return Status::OK(); + } - int row_count = gsl::narrow(input->Shape().SizeToDimension(input_dims_size - 1)); typedef typename ToCudaType::MappedType CudaT; + const int skip_size = onnxruntime::narrow(skip->Shape().Size()); + if (strict_) { HostApplyLayerNorm( GetDeviceProp(), @@ -97,21 +103,20 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, // beta reinterpret_cast(skip->Data()), // skip or residual to add (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr); + sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr); } else { LaunchSkipLayerNormKernel( Stream(ctx), reinterpret_cast(output->MutableData()), - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, + sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr, reinterpret_cast(input->Data()), reinterpret_cast(skip->Data()), + (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, reinterpret_cast(gamma->Data()), (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, - (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, epsilon_, hidden_size, row_count, - skip_broadcasted, skip_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index f2ee076a8a..50c8e4b5e0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -51,61 +51,68 @@ half maybe2half(float x) { // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time. -constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048}; +constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192}; +constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr int kMaxSize = kSizes[kNumOfSizes - 1]; constexpr int kMinBlockSize = 32; -constexpr int kMaxBlockSize = 256; +constexpr int kMaxBlockSize = 1024; int NextSize(int x) { - size_t len = sizeof(kSizes) / sizeof(kSizes[0]); - for (size_t i = 0; i < len; ++i) { + for (size_t i = 0; i < kNumOfSizes; ++i) { if (x <= kSizes[i]) { return kSizes[i]; } } - return kSizes[len - 1]; + return kMaxSize + 1; } -template -bool CanVectorized(T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, const int ld, const int next_size) { - constexpr int alignment = std::alignment_of>::value; - return ld % NumUnroll == 0 && reinterpret_cast(output) % alignment == 0 && - reinterpret_cast(skip_input_bias_add_output) % alignment == 0 && - reinterpret_cast(input) % alignment == 0 && reinterpret_cast(skip) % alignment == 0 && - reinterpret_cast(gamma) % alignment == 0 && reinterpret_cast(beta) % alignment == 0 && - reinterpret_cast(bias) % alignment == 0 && next_size / NumUnroll >= kMinBlockSize && - next_size / NumUnroll <= kMaxBlockSize; +bool CanVectorized(void* output, void* sum_output, const void* input, const void* skip, const void* bias, + const void* gamma, const void* beta, const int ld, const int next_size, int num_unroll, int element_size) { + int alignment = element_size * num_unroll; + return ld % num_unroll == 0 && + reinterpret_cast(output) % alignment == 0 && + reinterpret_cast(sum_output) % alignment == 0 && + reinterpret_cast(input) % alignment == 0 && + reinterpret_cast(skip) % alignment == 0 && + reinterpret_cast(bias) % alignment == 0 && + reinterpret_cast(gamma) % alignment == 0 && + reinterpret_cast(beta) % alignment == 0 && + next_size / num_unroll >= kMinBlockSize && + next_size / num_unroll <= kMaxBlockSize; } } // namespace template __global__ void SkipLayerNormKernel( - const int ld, const T* input, const T* skip, - const T* beta, const T* gamma, const T* bias, - const T epsilon, T* output, T* skip_input_bias_add_output, const bool skip_broadcasted, int skip_size) { + T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon, + const int ld, int skip_size) { const T reverse_ld = T(1.f / ld); const int offset = blockIdx.x * ld; + const bool has_bias = (bias != nullptr); + // Reduce sum of x and x^2, and the results are divided by ld. KeyValuePairSum pair_sum; - // reduce x and x^2 cub::KeyValuePair thread_data(0, 0); - for (int i = threadIdx.x; i < ld; i += TPB) { const int idx = offset + i; - const T skip_data = skip_broadcasted ? skip[idx % skip_size] : skip[idx]; - const T val = (bias == nullptr) ? input[idx] + skip_data : input[idx] + skip_data + bias[i]; + T val = input[idx]; + if (has_bias) { + val += bias[i]; + } + val += skip[idx % skip_size]; const T rldval = reverse_ld * val; thread_data = pair_sum(thread_data, cub::KeyValuePair(rldval, rldval * val)); - if (skip_input_bias_add_output != nullptr) { - skip_input_bias_add_output[idx] = val; + if (sum_output != nullptr) { + sum_output[idx] = val; } output[idx] = val; } + if (Simplified) { SimplifiedLayerNorm(thread_data.value, ld, offset, gamma, epsilon, output); return; @@ -116,106 +123,114 @@ __global__ void SkipLayerNormKernel( // Vectorized kernel template __global__ void SkipLayerNormKernelSmall( - const int ld, const T* input, const T* skip, const T* beta, const T* gamma, - const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput, const bool skip_broadcasted, const int skip_size) { + T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon, + int ld, int skip_size) { const T rld = T(1.f / ld); - const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld + const int idx = blockIdx.x * ld + threadIdx.x * ILP; using VecT = aligned_vector; + T sum_v[ILP]; - T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP]; + cub::KeyValuePair thread_data(T(0.f), T(0.f)); - VecT* input_val = reinterpret_cast(&input_v); - *input_val = *reinterpret_cast(&input[idx]); + if (ILP * threadIdx.x < ld) { // load data under this guard to avoid reading out-of-bounds + T skip_v[ILP], bias_v[ILP]; - VecT* skip_val = reinterpret_cast(&skip_v); - if (skip_broadcasted){ - *skip_val = *reinterpret_cast(&skip[idx % skip_size]); - }else{ - *skip_val = *reinterpret_cast(&skip[idx]); - } + // load input to sum_v + VecT* sum_val = reinterpret_cast(&sum_v); + *sum_val = *reinterpret_cast(&input[idx]); - if (hasBias) { - VecT* bias_val = reinterpret_cast(&bias_v); - *bias_val = *reinterpret_cast(&bias[threadIdx.x * ILP]); - } + VecT* skip_val = reinterpret_cast(&skip_v); + *skip_val = *reinterpret_cast(&skip[idx % skip_size]); - cub::KeyValuePair thread_data(T(0.f), T(0.f)); + const bool has_bias = (bias != nullptr); + if (has_bias) { + VecT* bias_val = reinterpret_cast(&bias_v); + *bias_val = *reinterpret_cast(&bias[threadIdx.x * ILP]); + } - if (ILP * threadIdx.x < ld) { T rldval_sum = T(0.f); T rldvalsq_sum = T(0.f); + const bool has_sum_output = (sum_output != nullptr); + #pragma unroll for (int i = 0; i < ILP; i++) { - input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i]; - - if (hasSkipInputBiasAdditionOutput) { - skip_input_bias_add_output_v[i] = input_v[i]; + if (has_bias) { + sum_v[i] += bias_v[i]; } + sum_v[i] += skip_v[i]; - const T rldval = rld * input_v[i]; + const T rldval = rld * sum_v[i]; rldval_sum += rldval; - rldvalsq_sum += rldval * input_v[i]; + rldvalsq_sum += rldval * sum_v[i]; } - if (hasSkipInputBiasAdditionOutput) { - *(reinterpret_cast(&skip_input_bias_add_output[idx])) = *reinterpret_cast(&skip_input_bias_add_output_v); + if (has_sum_output) { + *(reinterpret_cast(&sum_output[idx])) = *reinterpret_cast(&sum_v); } thread_data = cub::KeyValuePair(rldval_sum, rldvalsq_sum); } if (Simplified) { - SimplifiedLayerNormSmall(input_v, thread_data.value, ld, idx, gamma, epsilon, output); + SimplifiedLayerNormSmall(sum_v, thread_data.value, ld, idx, gamma, epsilon, output); return; } - LayerNormSmall(input_v, thread_data, ld, idx, beta, gamma, epsilon, output); + LayerNormSmall(sum_v, thread_data, ld, idx, beta, gamma, epsilon, output); } template void LaunchSkipLayerNormKernel( - cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, float epsilon, int ld, int row_count, bool skip_broadcasted, int skip_size) { - if (row_count == 0) { - return; - } - - bool hasBias = (bias == nullptr) ? false : true; - bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true; - + cudaStream_t stream, T* output, T* sum_output, + const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, float epsilon, + int ld, int row_count, int skip_size) { const int next_size = NextSize(ld); const int grid_size = row_count; - bool flag_vec2 = - CanVectorized(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size); - bool flag_vec4 = - CanVectorized(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size); + bool can_unroll_vec4 = CanVectorized(output, sum_output, input, + skip, bias, gamma, + beta, ld, next_size, + 4, sizeof(T)); + bool can_unroll_vec8 = CanVectorized(output, sum_output, input, + skip, bias, gamma, + beta, ld, next_size, + 8, sizeof(T)); + +#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ + SkipLayerNormKernelSmall<<>>( \ + output, sum_output, input, skip, bias, gamma, beta, maybe2half(epsilon), ld, skip_size) + +#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \ + SkipLayerNormKernel<<>>( \ + output, sum_output, input, skip, bias, gamma, beta, maybe2half(epsilon), ld, skip_size) + +#define CASE_NEXT_SIZE(next_size_value) \ + case next_size_value: { \ + static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \ + if constexpr (next_size_value >= 320) { \ + if (can_unroll_vec8) { \ + constexpr int block_size = next_size_value / 8; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \ + } else { \ + constexpr int block_size = 256; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ + } \ + } else { \ + if (can_unroll_vec4) { \ + constexpr int block_size = next_size_value / 4; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \ + } else { \ + if (next_size_value <= kMaxBlockSize) { \ + constexpr int block_size = next_size_value; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \ + } else { \ + constexpr int block_size = 256; \ + LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ + } \ + } \ + } \ + } break switch (next_size) { -#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ - SkipLayerNormKernelSmall \ - <<>>(ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, \ - skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput, skip_broadcasted, skip_size) -#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \ - SkipLayerNormKernel<<>>( \ - ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, skip_input_bias_add_output, skip_broadcasted, skip_size) -#define CASE_NEXT_SIZE(next_size_value) \ - case next_size_value: { \ - if (flag_vec4) { \ - constexpr int block_size = next_size_value / 4; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \ - } else if (flag_vec2) { \ - constexpr int block_size = next_size_value / 2; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2); \ - } else { \ - if (next_size_value <= kMaxBlockSize) { \ - constexpr int block_size = next_size_value; \ - LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \ - } else { \ - LAUNCH_SKIP_LAYER_NORM_KERNEL(); \ - } \ - } \ - } break CASE_NEXT_SIZE(kSizes[0]); CASE_NEXT_SIZE(kSizes[1]); CASE_NEXT_SIZE(kSizes[2]); @@ -223,18 +238,27 @@ void LaunchSkipLayerNormKernel( CASE_NEXT_SIZE(kSizes[4]); CASE_NEXT_SIZE(kSizes[5]); CASE_NEXT_SIZE(kSizes[6]); + CASE_NEXT_SIZE(kSizes[7]); + CASE_NEXT_SIZE(kSizes[8]); + CASE_NEXT_SIZE(kSizes[9]); + CASE_NEXT_SIZE(kSizes[10]); + default: { + constexpr int block_size = 256; + LAUNCH_SKIP_LAYER_NORM_KERNEL(); + break; + } + } + #undef CASE_NEXT_SIZE #undef LAUNCH_SKIP_LAYER_NORM_KERNEL #undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL - } } -#define SKIPLAYERNORM_IMPL(T, Simplified) \ - template void LaunchSkipLayerNormKernel(cudaStream_t stream, T * output, \ - T * skip_input_bias_add_output, \ - const T* input, const T* skip, const T* gamma, \ - const T* beta, const T* bias, float epsilon, \ - int ld, int row_count, bool skip_broadcasted, int skip_size); +#define SKIPLAYERNORM_IMPL(T, Simplified) \ + template void LaunchSkipLayerNormKernel(cudaStream_t stream, T * output, T * sum_output, \ + const T* input, const T* skip, const T* bias, \ + const T* gamma, const T* beta, float epsilon, \ + int ld, int row_count, int skip_size); SKIPLAYERNORM_IMPL(float, true); SKIPLAYERNORM_IMPL(float, false); SKIPLAYERNORM_IMPL(half, true); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h index ffb5850c82..9727dd6236 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h @@ -11,18 +11,17 @@ namespace cuda { template void LaunchSkipLayerNormKernel( cudaStream_t stream, - T* output, // normalized output tensor - T* skip_input_bias_add_output, // sum of the input and skip (and bias if it exists) tensors output - const T* input, // input tensor - const T* skip, // skip tensor - const T* gamma, // Layer normalization gamma tensor - const T* beta, // Layer normalization beta tensor - const T* bias, // Layer normalization beta tensor - float epsilon, // Layer normalization epsilon - int hidden_size, // hidden size, it is the leading dimension (ld) - int row_count, // number of rows. That is total number of elements divided by hidden size. - bool skip_broadcasted, // determines if broadcasting should be implemented - int skip_size); // determines size of the skip tensor + T* output, // normalized output tensor + T* sum_output, // sum of the input and skip (and bias if it exists) tensors output + const T* input, // input tensor + const T* skip, // skip tensor + const T* bias, // bias tensor + const T* gamma, // Layer normalization gamma tensor + const T* beta, // Layer normalization beta tensor + float epsilon, // Layer normalization epsilon + int hidden_size, // hidden size, it is the leading dimension (ld) + int row_count, // number of rows. That is total number of elements divided by hidden size. + int skip_size); // number of elements of the skip tensor } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h index d61501f429..ce42e33ba1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h @@ -855,6 +855,139 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { false, false}, + {DATA_TYPE_FP16, + 32, + 32, + kSM_80, + cubin_fmha_v2_fp16_32_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_32_32_sm80_cu_cubin_len, + "fmha_v2_fp16_32_32_sm80_kernel", + 8192, + 128, + 0, + false, + false}, + {DATA_TYPE_FP16, + 64, + 32, + kSM_80, + cubin_fmha_v2_fp16_64_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_64_32_sm80_cu_cubin_len, + "fmha_v2_fp16_64_32_sm80_kernel", + 16384, + 128, + 0, + false, + false}, + {DATA_TYPE_FP16, + 96, + 32, + kSM_80, + cubin_fmha_v2_fp16_96_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_96_32_sm80_cu_cubin_len, + "fmha_v2_fp16_96_32_sm80_kernel", + 24576, + 128, + 0, + false, + false}, + {DATA_TYPE_FP16, + 128, + 32, + kSM_80, + cubin_fmha_v2_fp16_128_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_128_32_sm80_cu_cubin_len, + "fmha_v2_fp16_128_32_sm80_kernel", + 32768, + 128, + 0, + false, + false}, + {DATA_TYPE_FP16, + 128, + 32, + kSM_80, + cubin_fmha_v2_fp16_128_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_128_32_sm80_cu_cubin_len, + "fmha_v2_fp16_128_32_sm80_kernel_nl", + 20480, + 128, + 32, + false, + false}, + {DATA_TYPE_FP16, + 192, + 32, + kSM_80, + cubin_fmha_v2_fp16_192_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_192_32_sm80_cu_cubin_len, + "fmha_v2_fp16_192_32_sm80_kernel", + 16384, + 128, + 0, + false, + false}, + {DATA_TYPE_FP16, + 192, + 32, + kSM_80, + cubin_fmha_v2_fp16_192_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_192_32_sm80_cu_cubin_len, + "fmha_v2_fp16_192_32_sm80_kernel_nl", + 16384, + 128, + 32, + false, + false}, + {DATA_TYPE_FP16, + 256, + 32, + kSM_80, + cubin_fmha_v2_fp16_256_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_256_32_sm80_cu_cubin_len, + "fmha_v2_fp16_256_32_sm80_kernel", + 20480, + 128, + 0, + false, + false}, + {DATA_TYPE_FP16, + 256, + 32, + kSM_80, + cubin_fmha_v2_fp16_256_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_256_32_sm80_cu_cubin_len, + "fmha_v2_fp16_256_32_sm80_kernel_nl", + 20480, + 128, + 32, + false, + false}, + {DATA_TYPE_FP16, + 384, + 32, + kSM_80, + cubin_fmha_v2_fp16_384_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_384_32_sm80_cu_cubin_len, + "fmha_v2_fp16_384_32_sm80_kernel", + 32768, + 256, + 0, + false, + false}, + {DATA_TYPE_FP16, + 384, + 32, + kSM_80, + cubin_fmha_v2_fp16_384_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_384_32_sm80_cu_cubin_len, + "fmha_v2_fp16_384_32_sm80_kernel_nl", + 32768, + 256, + 32, + false, + false}, + // GA10x: sm86 uses sm80 kernels {DATA_TYPE_FP16, 32, diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 86c1cb93e8..f7a19b988e 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -71,6 +71,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); @@ -89,10 +90,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -112,6 +116,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); @@ -218,6 +226,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -236,10 +245,13 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -259,6 +271,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc index a38dfd34cc..274bc9a730 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc @@ -44,11 +44,6 @@ Status BiasAdd::ComputeInternal(OpKernelContext* context) const { "The input is expected to have 3 dimensions, got ", input_dims.size()); } - if (input_dims[2] != 320 && input_dims[2] != 640 && input_dims[2] != 1280) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels should be 320, 640 or 1280, got ", input_dims[2]); - } - const Tensor* bias = context->Input(1); const auto& bias_dims = bias->Shape().GetDims(); if (bias_dims.size() != 1) { diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu index 2983cc99e3..8e8068b5e5 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu @@ -42,6 +42,17 @@ __global__ void BiasAddKernel(T const* input, T const* bias, T const* residual, } } +template +__global__ void BiasAddLargeKernel( + int32_t const ld, const T* input, const T* bias, const T* residual, T* output) { + int32_t const offset = blockIdx.x * ld; + + for (int32_t i = threadIdx.x; i < ld; i += TPB) { + int32_t const base_offset = offset + i; + output[base_offset] = input[base_offset] + bias[i] + residual[base_offset]; + } +} + template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); @@ -52,19 +63,19 @@ template __global__ void BiasAddKernel(half const*, half const* template void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels, T const* input, T const* bias, T const* residual, T* output) { - constexpr int32_t TPB = 320; // thread per block switch (num_channels) { case 320: - (BiasAddKernel)<<>>(input, bias, residual, output); + (BiasAddKernel)<<>>(input, bias, residual, output); break; case 640: - (BiasAddKernel)<<>>(input, bias, residual, output); + (BiasAddKernel)<<>>(input, bias, residual, output); break; case 1280: - (BiasAddKernel)<<>>(input, bias, residual, output); + (BiasAddKernel)<<>>(input, bias, residual, output); break; default: - ORT_NOT_IMPLEMENTED("Not implemented"); + BiasAddLargeKernel<<>>(num_channels, input, bias, residual, output); + break; } } diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc index 2b13cdbd80..cb02bd8541 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc @@ -39,9 +39,13 @@ Status BiasSplitGelu::ComputeInternal(OpKernelContext* context) const { "input is expected to have 3 dimensions, got ", input_dims.size()); } - if (input_dims[2] != 2560 && input_dims[2] != 5120 && input_dims[2] != 10240) { + if (input_dims[2] != 2560 && + input_dims[2] != 5120 && + input_dims[2] != 6144 && + input_dims[2] != 10240 && + input_dims[2] != 12288) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "hidden size should be 2560, 5120 or 10240, got ", input_dims[2]); + "hidden size should be 2560, 5120, 6144, 10240 or 12288, got ", input_dims[2]); } const Tensor* bias = context->Input(1); diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu index 19e05a9573..3ae9611d4d 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu @@ -65,6 +65,12 @@ void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t h case 5120: (biasSplitGeluKernel)<<>>(input, bias, output); break; + case 3072: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; + case 6144: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; default: ORT_NOT_IMPLEMENTED("Not implemented"); } @@ -73,9 +79,13 @@ void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t h template __global__ void biasSplitGeluKernel(float const*, float const*, float*); template __global__ void biasSplitGeluKernel(float const*, float const*, float*); template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); template __global__ void biasSplitGeluKernel(half const*, half const*, half*); template __global__ void biasSplitGeluKernel(half const*, half const*, half*); template __global__ void biasSplitGeluKernel(half const*, half const*, half*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); template void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, float const* input, float const* bias, float* output); diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 301b2e76b1..87e88ac31c 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -1,6 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/diffusion/group_norm.h" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" @@ -15,14 +14,22 @@ ONNX_OPERATOR_KERNEL_EX( GroupNorm, kMSDomain, 1, kCudaExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); +ONNX_OPERATOR_KERNEL_EX( + SkipGroupNorm, kMSDomain, 1, kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); + using namespace ONNX_NAMESPACE; namespace { + template struct DispatchGroupNorm { Status operator()(cudaStream_t stream, Tensor* output, + Tensor* add_out, const Tensor* input, + const Tensor* skip, + const Tensor* bias, const Tensor* gamma, const Tensor* beta, void* workspace, @@ -32,12 +39,17 @@ struct DispatchGroupNorm { int height, int width, int num_groups, - bool use_swish_activation) { + bool use_swish_activation, + bool broadcast_skip, + int channels_per_block) { typedef typename ToCudaType::MappedType CudaT; return LaunchGroupNormKernel( stream, reinterpret_cast(output->MutableData()), + add_out == nullptr ? nullptr : reinterpret_cast(add_out->MutableData()), reinterpret_cast(input->Data()), + skip == nullptr ? nullptr : reinterpret_cast(skip->Data()), + bias == nullptr ? nullptr : reinterpret_cast(bias->Data()), gamma->Data(), beta->Data(), workspace, @@ -47,13 +59,21 @@ struct DispatchGroupNorm { height, width, num_groups, - use_swish_activation); + use_swish_activation, + broadcast_skip, + channels_per_block); } }; } // namespace GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { + has_skip_ = false; + const std::string& op_name = op_info.GetKernelDef().OpName(); + if (op_name == "SkipGroupNorm") { + has_skip_ = true; + } + epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); ORT_ENFORCE(epsilon_ >= 0); @@ -68,6 +88,23 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { use_swish_activation_ = (activation == 1); channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); + + channels_per_block_ = 0; +} + +Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + + // Compute and cache cPerBlock using number of channels from gamma tensor shape. + if (input_idx == 1) { + auto gamma_shape = tensor.Shape(); + if (gamma_shape.NumDimensions() == 1) { + channels_per_block_ = GetChannelsPerBlock(static_cast(gamma_shape[0]), num_groups_); + } + } + + return Status::OK(); } Status GroupNorm::ComputeInternal(OpKernelContext* context) const { @@ -77,22 +114,38 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, input->Shape()); if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "only the channels_last layout is supported"); } + if (!gamma->IsDataType() || !beta->IsDataType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm only supports gamma and beta in float type"); + } + const auto& input_dims = input->Shape().GetDims(); if (input_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input is expected to have 4 dimensions, got ", input_dims.size()); } + // Only support NHWC format right now. + int batch_size = static_cast(input_dims[0]); + int height = static_cast(input_dims[1]); + int width = static_cast(input_dims[2]); + int num_channels = static_cast(input_dims[3]); + + if (num_channels % num_groups_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "number of channels should be divisiable by num_groups"); + } + const auto& gamma_dims = gamma->Shape().GetDims(); if (gamma_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected to have 1 dimension, got ", gamma_dims.size()); } - if (gamma_dims[0] != input_dims[3]) { + if (gamma_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in gamma and input does not match"); } @@ -102,22 +155,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "beta is expected to have 1 dimension, got ", beta_dims.size()); } - if (beta_dims[0] != input_dims[3]) { + if (beta_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in beta and input does not match"); } - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisiable by num_groups"); - } - if (context->GetUseDeterministicCompute()) { static std::once_flag log_warning; std::call_once(log_warning, []() { @@ -125,17 +167,59 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { }); } - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); + const Tensor* skip = nullptr; + const Tensor* bias = nullptr; + Tensor* add_out = nullptr; + + bool broadcast_skip = false; + if (has_skip_) { + skip = context->Input(3); + bias = context->Input(4); + add_out = context->Output(1, input->Shape()); + + if (bias != nullptr) { // Bias is optional + // If provided, bias has shape (C). + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimension, got ", bias_dims.size()); + } + if (bias_dims[0] != num_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in bias and input does not match"); + } + } + + // Check whether skip can be broadcasted to input shape. + if (skip->Shape() != input->Shape()) { + const auto& dims = skip->Shape().GetDims(); + // The shape of ship can be (N, C) or (N, 1, 1, C) for broadcast. + const bool b2 = (dims.size() == 2 && dims[0] == batch_size && dims[1] == num_channels); + const bool b4 = (dims.size() == 4 && dims[0] == batch_size && + dims[1] == 1 && dims[2] == 1 && dims[3] == num_channels); + broadcast_skip = b2 || b4; + if (!broadcast_skip) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip shape is expected to be (N, H, W, C) or (N, 1, 1, C) or (N, C)"); + } + } + } + + auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_), + context->GetComputeStream()); utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(Stream(context), output, input, gamma, beta, workspace.get(), + return dispatcher.InvokeRet(Stream(context), output, add_out, input, skip, bias, + gamma, beta, workspace.get(), epsilon_, batch_size, num_channels, height, width, num_groups_, - use_swish_activation_); + use_swish_activation_, + broadcast_skip, + channels_per_block_); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index 52c006e6bd..b408b3c1ee 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -16,11 +16,16 @@ class GroupNorm final : public CudaKernel { GroupNorm(const OpKernelInfo& op_kernel_info); Status ComputeInternal(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + private: - bool use_swish_activation_; + bool use_swish_activation_; // use SiLU (also known as Swish) activation after group normalization? float epsilon_; int num_groups_; bool channels_last_; + bool has_skip_; // true for SkipGroupNorm operator; false for GroupNorm + int channels_per_block_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 01ba078b4b..48b161552c 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -16,18 +16,45 @@ */ // The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include #include #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +using namespace onnxruntime::cuda; + namespace onnxruntime { namespace contrib { namespace cuda { -static inline int32_t divUp(int32_t m, int32_t n) { +namespace { + +// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. +constexpr static int32_t CHANNELS_PER_THREAD = 2; + +constexpr static int kSizes[] = {128, 256, 320, 384, 512}; +constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; + +int NextSize(int x) { + for (size_t i = 0; i < kNumOfSizes; ++i) { + if (x <= kSizes[i]) { + return kSizes[i]; + } + } + + return x; +} +} // namespace + +static inline int32_t DivUp(int32_t m, int32_t n) { return (m + n - 1) / n; } @@ -41,14 +68,14 @@ struct GroupSums { // The sum. float sum; // The sum of squares. - float sumSq; + float sum_sq; }; struct GroupSumsOp { inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { GroupSums dst; dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); + dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); dst.flag = a.flag + b.flag; return dst; } @@ -56,54 +83,85 @@ struct GroupSumsOp { template struct GroupNormNHWCParams { - // The output buffer. Layout NHWC. + // The output buffer. Shape is (n, h, w, c). T* dst; - // The input buffer. Layout NHWC. + + // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). + T* add_out; + + // The input buffer. Shape is (n, h, w, c). T const* src; + + // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). + T const* skip; + + // Optional input buffer for bias tensor. Shape is (c). + T const* bias; + // The gamma scaling factor. float const* gamma; + // The beta term to add in GN. float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; + + // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. + float* group_sum_buffer; // The number of instances in the batch. int32_t n; + // The height and width of each activation map. int32_t h; int32_t w; - // The number of channels. + + // Number of channels. int32_t c; - // The number of groups. + + // Number of groups. int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; + + // Do we apply the SiLU activation function? + bool use_silu; // Precomputed values and parameters to control the execution of the kernels. - // The number of activations per instance (h * w) and the number of - // activations per block. + // Number of activations per instance (h * w) int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; + + // Number of activations per block + int32_t hw_per_block; + + // Number of channels per block in the C dimension. + int32_t channels_per_block; + + // Number of channels per group in the C dimension. + int32_t channels_per_group; // The precomputed stride between instances. int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; + // The inverse of hw*channels_per_group to compute mean of a group. + float inv_hw_channels_per_group; // The precomputed number of groups per block. - int32_t groupsPerBlock; + int32_t groups_per_block; + + // Number of threads per block + int32_t threads_per_block; + + // Epsilon to get stable variance in normalization. + float epsilon; + + // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. + bool broadcast_skip; + + // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. + T* skip_workspace; }; template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq); +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -113,11 +171,11 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, fl sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); @@ -125,119 +183,220 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, f sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] +template +inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); + h2 = h2 + b; + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + float2 b = *reinterpret_cast(&bias[bias_offset]); + f2.x += s.x + b.x; + f2.y += s.y + b.y; + + *reinterpret_cast(&add_out[offset]) = f2; + + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] +template +inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + f2.x += s.x; + f2.y += s.y; + *reinterpret_cast(&add_out[offset]) = f2; + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } -template -__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { +template +__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { // The object in charge of doing the sums for the different blocks. - typedef cub::BlockScan BlockScan; + typedef cub::BlockScan BlockScan; // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[tTHREADS_PER_BLOCK]; + __shared__ typename BlockScan::TempStorage temp_storage; + + // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. + __shared__ float2 smem[THREADS_PER_BLOCK]; // The instance in the batch. int32_t ni = blockIdx.z; - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t hw_begin = blockIdx.y * params.hw_per_block; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); // The sums. float sum = 0.F; - float sumSq = 0.F; + float sum_sq = 0.F; // Iterate over the activations to compute the sums. - if (ci < params.c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hwi) * params.c + ci; - UpdateSum(params.src, offset, sum, sumSq); + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + if (params.skip != nullptr) { + // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) + const int64_t bias_offset = static_cast(ci); + T* add_out = params.skip_workspace; + if (params.broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * params.c + ci; + + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); + } + } + } else { + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); + } + } + } + } else { // GroupNorm + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + UpdateSum(params.src, offset, sum, sum_sq); } } - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * 2 / params.cPerGroup; - int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi; + // The group index relative to the first group within the same block. + int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; + // The channel in the group. + int32_t cj = ci % params.channels_per_group; // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; + GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; - // Do the segmented scan. + // Do the segmented scan. InclusiveScan is not deterministic. GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); + BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == params.cPerGroup - 2) { //2 channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); + // Store the results for the groups in shared memory (to produce coalesced stores later). + // For each group, only the last thread of that group is picked to save sum to shared memory. + if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { + smem[gi] = make_float2(out.sum, out.sum_sq); } // Make sure the data is in shared memory. __syncthreads(); - // The global group index. - int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x; - // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) { + if (threadIdx.x >= params.groups_per_block) { return; } - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. - atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x); - atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); + // The global group index. + // Use neighboring threads for coalesced write. + int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; + + if (gj < params.groups) { + float2 sums = smem[threadIdx.x]; + const int index = (2 * ni) * params.groups + gj; + atomicAdd(¶ms.group_sum_buffer[index], sums.x); + atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); + } } template -void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the values are as we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); + // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); + // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCSumKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCSumKernel<<>>(params); break; - case 480: - groupNormNHWCSumKernel<<>>(params); + case 192: + GroupNormNHWCSumKernel<<>>(params); break; - case 256: - groupNormNHWCSumKernel<<>>(params); + case 160: + GroupNormNHWCSumKernel<<>>(params); break; case 128: - groupNormNHWCSumKernel<<>>(params); + GroupNormNHWCSumKernel<<>>(params); + break; + case 64: + GroupNormNHWCSumKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish); +__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu); template <> -__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -245,15 +404,15 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo float2 f2 = __half22float2(h2); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -262,21 +421,21 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo } template <> -__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -284,110 +443,142 @@ __device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, f *reinterpret_cast(&dst[offset]) = f2; } -template -__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; - if (ci >= params.c) { +template +__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { return; } // The instance in the batch. int32_t ni = blockIdx.z; - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / params.cPerGroup; + // The group that thread works on. + int32_t gi = ci / params.channels_per_group; // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; + float sum = 0.F, sum_sq = 0.F; if (gi < params.groups) { - sum = params.redBuffer[(2 * ni + 0) * params.groups + gi]; - sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; + const int index = (2 * ni) * params.groups + gi; + sum = params.group_sum_buffer[index]; + sum_sq = params.group_sum_buffer[index + params.groups]; } - // Load gamma/beta. - float2 gammaF2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 betaF2 = *reinterpret_cast(¶ms.beta[ci]); + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); // Compute the mean. - float mean = sum * params.invHWC; + float mean = sum * params.inv_hw_channels_per_group; // Compute the variance. - float var = sumSq * params.invHWC - (mean * mean); + float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var); + float inv_std_dev = rsqrtf(var + params.epsilon); - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_begin = blockIdx.y * params.hw_per_block; + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; - - // Fetch two channels per thread. - computeGroupNorm(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish); + const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); } } template -void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCScaleKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCScaleKernel<<>>(params); break; - case 480: - groupNormNHWCScaleKernel<<>>(params); + case 192: + GroupNormNHWCScaleKernel<<>>(params); break; - case 256: - groupNormNHWCScaleKernel<<>>(params); + case 160: + GroupNormNHWCScaleKernel<<>>(params); break; case 128: - groupNormNHWCScaleKernel<<>>(params); + GroupNormNHWCScaleKernel<<>>(params); + break; + case 64: + GroupNormNHWCScaleKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { + int32_t max_divisor = -1; for (int32_t i = 1; i <= std::sqrt(n); i++) { if (n % i == 0) { int32_t divisor1 = n / i; int32_t divisor2 = i; - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; + if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { + max_divisor = divisor1; } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; + if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { + max_divisor = divisor2; } } } - return maxDivisor; + return max_divisor; +} + +// Find proper channels per block based on a cost function: The cost is number of channels corresponding to +// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has +// work to do so it is ideal case. +int FindChannelsPerBlock(int num_channels, int channels_per_group) { + int min_cost = -1; + int best_candidate = -1; + for (size_t i = kNumOfSizes; i > 0; --i) { + if (kSizes[i - 1] < channels_per_group) { + break; + } + + int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; + int blocks = (num_channels + channels_per_block - 1) / channels_per_block; + int cost = blocks * kSizes[i - 1] - num_channels; + if (cost == 0) { + return channels_per_block; + } + + if (min_cost == -1 || cost < min_cost) { + min_cost = cost; + best_candidate = channels_per_block; + } + } + + return best_candidate; +} + +int GetChannelsPerBlock(int num_channels, int num_groups) { + int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_block = channels_per_group; + if (channels_per_group < kMaxSize / 2) { + channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); + } + return channels_per_block; } template Status LaunchGroupNormKernel( cudaStream_t stream, T* output, + T* add_out, const T* input, + const T* skip, + const T* bias, const float* gamma, const float* beta, void* workspace, @@ -397,79 +588,94 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool use_swish_activation) { - if (batch_size > static_cast(kMaxGroupNormBatchSize)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only support batch_size <= 32. Got", batch_size); - } + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCParams params; - if (num_groups != static_cast(kGroupNormNumberOfGroups)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only num_groups=32 is supported. Got", num_groups); + int32_t channels_per_group = num_channels / num_groups; + // channels_per_block is computed in PrePack. + // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. + if (channels_per_block < channels_per_group) { + channels_per_block = GetChannelsPerBlock(num_channels, num_groups); } - GroupNormNHWCParams params; - int32_t cPerBlock = 320; - int32_t maxBlocksPerHW = 1024; - switch (num_channels) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; + // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases + if (channels_per_block % channels_per_group != 0 || + channels_per_block > kMaxSize || + (channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in CUDA does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - params.withSwish = use_swish_activation; + params.use_silu = use_silu; params.dst = output; + params.add_out = add_out; params.src = input; + params.skip = skip; + params.bias = bias; params.gamma = gamma; params.beta = beta; - params.redBuffer = reinterpret_cast(workspace); + params.group_sum_buffer = reinterpret_cast(workspace); params.n = batch_size; params.h = height; params.w = width; params.c = num_channels; params.groups = num_groups; params.hw = params.h * params.w; - const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW); - params.hwPerBlock = divUp(params.hw, blocksPerHW); - params.cPerBlock = cPerBlock; - params.cPerGroup = params.c / params.groups; + + // This will allocate as many blocks as possible to partition HW. + // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. + // TODO: tune this logic to find proper blocks when hw is small. + constexpr int32_t max_blocks_per_hw = 1024; + const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw); + params.hw_per_block = DivUp(params.hw, blocks_per_hw); + + params.channels_per_block = channels_per_block; + params.channels_per_group = channels_per_group; params.hwc = params.hw * params.c; - params.invHWC = 1.F / (float)(params.hw * params.cPerGroup); - params.groupsPerBlock = cPerBlock / params.cPerGroup; + params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group); + params.groups_per_block = channels_per_block / params.channels_per_group; + params.epsilon = epsilon; + params.broadcast_skip = broadcast_skip; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("input", input, batch_size, num_channels, height * width); - DUMP_TENSOR("gamma", gamma, 1, num_channels); - DUMP_TENSOR("beta", beta, 1, num_channels); - cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream); - groupNormNHWCSum(params, stream); - DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2); + // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. + params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst; + + params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD; + + CUDA_RETURN_IF_ERROR(cudaMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); + + GroupNormNHWCSum(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - groupNormNHWCScale(params, stream); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("workspace", params.group_sum_buffer, batch_size, 2, num_groups); + + GroupNormNHWCScale(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_TENSOR("output", output, batch_size, num_channels, height * width); + return Status::OK(); } -template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, half* add_out, + const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); -template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, float* add_out, + const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h index c7e9245050..9532aeecb2 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -12,29 +12,33 @@ namespace onnxruntime { namespace contrib { namespace cuda { -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { +constexpr size_t GetGroupNormWorkspaceSizeInBytes(size_t batch_size, size_t num_groups) { // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; + return (sizeof(float) * 2) * batch_size * num_groups; } +int GetChannelsPerBlock(int num_channels, int num_groups); + template Status LaunchGroupNormKernel( cudaStream_t stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization + T* output, // normalized output tensor. Shape is (n, h, w, c) + T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c) + const T* input, // input tensor. Shape is (n, h, w, c) + const T* skip, // optional skip tensor. Shape is (n, h, w, c) + const T* bias, // optional bias tensor. Shape is (c) for SkipGroupNorm or (n, c) for BiasGroupNorm + const float* gamma, // gamma (also known as weight or scale). Shape is (c) + const float* beta, // beta (also known as bias). Shape is (c) + void* workspace, // Work space + float epsilon, // epsilon used normalization + int batch_size, // N + int num_channels, // C + int height, // H + int width, // W + int num_groups, // number of groups + bool use_silu, // Whether there is Sigmoid Linear Unit (SiLU) activation after group normalization + bool broadcast_skip, // Whether skip need broadcast. When skip has shape (n, c) or (n, 1, 1, c), it need broadcast. + int channels_per_block // Pre-computed channels per block. ); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index f899a73ee0..705f2d49fe 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -174,8 +174,9 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present = context->Output(1, present_shape); void* fused_runner = nullptr; // TODO(tianleiwu): use fused kernel to speed up - bool use_fused_cross_attention = false; - bool use_memory_efficient_attention = false; + constexpr bool use_fused_cross_attention = false; + constexpr bool use_memory_efficient_attention = false; + constexpr bool use_flash_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, parameters.num_heads, @@ -185,6 +186,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { parameters.kv_sequence_length, parameters.total_sequence_length, fused_runner, + use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention); @@ -193,27 +195,21 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.gemm_buffer = reinterpret_cast(gemm_buffer.get()); - data.bias = nullptr; // bias has been added - data.query = nullptr; - data.key = nullptr; - data.value = nullptr; - data.mask_index = (nullptr == mask_index) ? nullptr : mask_index->Data(); - data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); - data.past = (nullptr == past_tensor) ? nullptr : reinterpret_cast(past_tensor->Data()); - data.past_key = nullptr; - data.past_value = nullptr; - data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention + if (nullptr != mask_index) { + data.mask_index = mask_index->Data(); + data.mask_index_dims = mask_index->Shape().GetDims(); + } + + if (nullptr != past_tensor) { + data.past = reinterpret_cast(past_tensor->Data()); + } + data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); - data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData()); - data.present_key = nullptr; - data.present_value = nullptr; - data.fused_runner = fused_runner; - data.fused_cross_attention_kernel = nullptr; - data.use_memory_efficient_attention = use_memory_efficient_attention; - data.cumulated_sequence_length_q_cache = nullptr; - data.cumulated_sequence_length_kv_cache = nullptr; + if (nullptr != present) { + data.present = reinterpret_cast(present->MutableData()); + } return QkvToContext(GetDeviceProp(), cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu new file mode 100644 index 0000000000..8c328d00b4 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -0,0 +1,131 @@ +// Modifications: scaling is moved from masked softmax to the gemm before that. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "dequantize_blockwise.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, half scale, half zp, half* output) { + half2 scale_half2 = {scale, scale}; + half zp_adjust = -scale * __short2half_rn(zp); + half2 zp_adjust2 = {zp_adjust, zp_adjust}; + + alignas(16) half2 results[4]; + half v0 = __uint2half_rn(values_quant & 0xF); + half v1 = __uint2half_rn((values_quant >> 4) & 0xF); + results[0] = __halves2half2(v0, v1) * scale_half2 + zp_adjust2; + + half v2 = __uint2half_rn((values_quant >> 8) & 0xF); + half v3 = __uint2half_rn((values_quant >> 12) & 0xF); + results[1] = __halves2half2(v2, v3) * scale_half2 + zp_adjust2; + + half v4 = __uint2half_rn((values_quant >> 16) & 0xF); + half v5 = __uint2half_rn((values_quant >> 20) & 0xF); + results[2] = __halves2half2(v4, v5) * scale_half2 + zp_adjust2; + + half v6 = __uint2half_rn((values_quant >> 24) & 0xF); + half v7 = __uint2half_rn((values_quant >> 28) & 0xF); + results[3] = __halves2half2(v6, v7) * scale_half2 + zp_adjust2; + *(reinterpret_cast(output)) = *(reinterpret_cast(results)); +} + +__device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, float scale, float zp, float* output) { + float zp_adjust = -scale * zp; + output[0] = float(values_quant & 0xF) * scale + zp_adjust; + output[1] = float((values_quant >> 4) & 0xF) * scale + zp_adjust; + output[2] = float((values_quant >> 8) & 0xF) * scale + zp_adjust; + output[3] = float((values_quant >> 12) & 0xF) * scale + zp_adjust; + output[4] = float((values_quant >> 16) & 0xF) * scale + zp_adjust; + output[5] = float((values_quant >> 20) & 0xF) * scale + zp_adjust; + output[6] = float((values_quant >> 24) & 0xF) * scale + zp_adjust; + output[7] = float((values_quant >> 28) & 0xF) * scale + zp_adjust; +} + +template +__global__ void Dequantize4BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const uint8_t* zero_points, + int block_size, + int blocks_per_threadblock, + int shift) { + int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); + int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + T scale = *(scale_data + block_id); + uint8_t zp = 8; + if (zero_points) { + zp = (block_id & 0x01) ? (zero_points[block_id / 2] >> 4) : (zero_points[block_id / 2] & 0x0f); + } + + output = output + element_offset; + DequantizeEightElements(quant_value, scale, static_cast(zp), output); +} + +template +Status Dequantize4Bits( + T* output, + const uint8_t* quant_data, + const T* scales_data, + const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2] + int k, + int n, + int block_size, + cudaStream_t stream) { + // k is padded and equal to block_per_K * block_size + ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); + constexpr int element_per_thread = 8; + int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int blocks_per_K = k / block_size; + int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); + int shift = static_cast(log2f(float(block_size))); + + Dequantize4BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + blocks_per_threadblock, + shift); + + return Status::OK(); +} + +template Status Dequantize4Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const uint8_t* zero_points, + int k, + int n, + int block_size, + cudaStream_t stream); + +template Status Dequantize4Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const uint8_t* zero_points, + int k, + int n, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh new file mode 100644 index 0000000000..741ce1e735 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +template +Status Dequantize4Bits( + T* output, + const uint8_t* quant_data, + const T* scales_data, + const uint8_t* zero_points, + int k, + int n, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu new file mode 100644 index 0000000000..e58723f0b3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h" +#include "dequantize_blockwise_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + T host_quant_map[16]; + switch (quant_type) { + case FP4: + for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast(fp4_qaunt_map[i]); + break; + case NF4: + for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast(nf4_qaunt_map[i]); + break; + } + CUDA_CALL_THROW(cudaMemcpyAsync(quant_map_buffer, host_quant_map, sizeof(T) * 16, cudaMemcpyHostToDevice, stream)); + + return Status::OK(); +} + +template Status SetBnbQuantMap(int quant_type, float* quant_map_buffer, cudaStream_t stream); + +template Status SetBnbQuantMap(int quant_type, half* quant_map_buffer, cudaStream_t stream); + +template +__global__ void kDequantizeBlockwise( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + const int block_size, + const int n) { + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH * 2]; + uint8_t qvals[NUM_PER_TH]; + T local_abs_max = T(0.0f); + + typedef cub::BlockLoad LoadChar; + typedef cub::BlockStore StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { + valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; + valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2; + + local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]); + + __syncthreads(); + LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128); + + #pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max; + vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max; + #else + // half multiplication not supported + vals[j * 2] = static_cast(static_cast(quant_map[qvals[j] >> 4]) * static_cast(local_abs_max)); + vals[j * 2 + 1] = + static_cast(static_cast(quant_map[qvals[j] & 0x0F]) * static_cast(local_abs_max)); + #endif + } + + __syncthreads(); + StoreT(storet).Store(&(output[i * 2]), vals, valid_items_store); + } +} + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream) { + int tile_size = 1024; + kDequantizeBlockwise<<<(numel + tile_size - 1) / tile_size, 64, 0, stream>>>( + quant_map, + output, + quant_data, + absmax, + block_size / 2, + numel); + + return Status::OK(); +} + +template Status DequantizeBnb4( + const float* quant_map, + float* output, + const uint8_t* quant_data, + const float* absmax, + int block_size, + int numel, + cudaStream_t stream); + +template Status DequantizeBnb4( + const half* quant_map, + half* output, + const uint8_t* quant_data, + const half *absmax, + int block_size, + int numel, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh new file mode 100644 index 0000000000..4aef3ab699 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream); + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc new file mode 100644 index 0000000000..bd5b6e0a8a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h" +#include "matmul_bnb4.cuh" +#include "dequantize_blockwise_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulBnb4 final : public CudaKernel { + public: + MatMulBnb4(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("quant_type", &quant_type_)); + ORT_ENFORCE( + quant_type_ == FP4 || quant_type_ == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t quant_type_; +}; + +template +Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(0); + const Tensor* b_quant = ctx->Input(1); + const Tensor* absmax = ctx->Input(2); + + const auto* a_data = a->Data(); + const uint8_t* b_quant_data = b_quant->Data(); + const auto* absmax_data = absmax->Data(); + + typedef typename ToCudaType::MappedType CudaT; + + // TODO: find a better way to create the quant_map without using a buffer + // don't want to use malloc directly so asking from the caller + // can create a __device__ static array for float but doesn't work for half + IAllocatorUniquePtr quant_map_buffer = GetScratchBuffer(16, ctx->GetComputeStream()); + auto* quant_map_buffer_data = quant_map_buffer.get(); + ORT_RETURN_IF_ERROR(SetBnbQuantMap( + SafeInt(quant_type_), + reinterpret_cast(quant_map_buffer_data), + static_cast(ctx->GetComputeStream()->GetHandle()))); + + constexpr bool transa = false; + constexpr bool transb = true; + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR( + helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) return Status::OK(); + + bool is_4bit_done = TryMatMulBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle())); + + if (!is_4bit_done) { + IAllocatorUniquePtr b_dequant_ptr = GetScratchBuffer(N_ * K_, ctx->GetComputeStream()); + auto* b_dequant_data = b_dequant_ptr.get(); + ORT_RETURN_IF_ERROR(DequantizeBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(b_dequant_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(block_size_), + SafeInt(N_ * K_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_dequant_data), + SafeInt(K_), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp())); + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu new file mode 100644 index 0000000000..1d9aa75ff3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include +#include +#include +#include "matmul_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define num_values_4bit 32 +template +__global__ void kgemm_4bit_inference_naive( + int M, + int N, + int K, + const T* __restrict__ A, + const uint8_t* B, + const T* absmax, + const T* datatype, + T* out, + int lda, + int ldb, + int ldc, + int block_size) { + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = (THREADS / 32) * blockIdx.x + warp_idx; + const int num_values_8bit = num_values_4bit / 2; + float local_C = 0.0f; + + uint8_t local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit / 4]; + T local_A[num_values_4bit / 4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + for (int i = threadIdx.x; i < 16; i++) quant_map[i] = T(datatype[i]); + __syncthreads(); + + // A: [1, K] + // B: [N, K] + for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) { + int inner_idx_halved = inner_idx / 2; + int offset_B = ldb * row_B; + int absidx = ((2 * offset_B) + inner_idx) / block_size; + local_absmax = __ldg(&(absmax[absidx])); + + if (row_B < N) { + if ((inner_idx_halved + num_values_8bit) < (K / 2)) { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = + reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + } else { + #pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + if ((inner_idx_halved) + j < (K / 2)) + local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } else { + #pragma unroll + for (int j = 0; j < (num_values_8bit); j++) local_B_4bit[j] = 0b01110111; + } + + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int k = 0; k < num_values_8bit / 4; k++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; + local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; + #else + // half multiplication not supported + local_B[k * 2] = + static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) * + static_cast(local_absmax)); + local_B[k * 2 + 1] = + static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) * + static_cast(local_absmax)); + #endif + } + + if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { + // this is also relatively important for performance + if (BITS == 16) { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 4) + i]; + } else { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; + reinterpret_cast(local_A)[1] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; + } + } else { + #pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + if (inner_idx + (i * num_values_4bit / 4) + k < K) + local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; + else + local_A[k] = T(0.0f); + } + } + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + local_C += static_cast(local_A[k] * local_B[k]); + #else + // half multiplication not supported + local_C += static_cast(local_A[k]) * static_cast(local_B[k]); + #endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if (row_B < N && warp_lane == 0) out[row_B] = T(local_C); +} + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (k % block_size != 0 || m > 1) { + return false; + } + // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] + if (block_size % 32 != 0 || block_size > 4096) { + return false; + } + + int lda = k; + int ldb = (k + 1) / 2; + int ldc = n; + int num_blocks = (n + 3) / 4; + + constexpr int bits = std::is_same_v ? 16 : 32; + kgemm_4bit_inference_naive<<>>( + m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size); + + return true; +} + +template bool TryMatMulBnb4( + const float* quant_map, + float* output, + const float* a_data, + const uint8_t* b_data_quant, + const float* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +template bool TryMatMulBnb4( + const half* quant_map, + half* output, + const half* a_data, + const uint8_t* b_data_quant, + const half* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh new file mode 100644 index 0000000000..743234282f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc new file mode 100644 index 0000000000..14a8163fef --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define MatMulFp32Q4 operator, it is basically +// matmul float32 with right hand side being a 2-D matrix +// pre-packed and block-compacted into int4 +// + +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "matmul_nbits.cuh" +#include "dequantize_blockwise.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulNBits final : public CudaKernel { + public: + MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; +}; + +template +Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(0); + const Tensor* b = ctx->Input(1); + const Tensor* scales = ctx->Input(2); + const Tensor* zero_points = ctx->Input(3); + + const auto* a_data = a->Data(); + const uint8_t* blob_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + + ORT_ENFORCE(nbits_ == 4, "only 4 bits is supported now"); + + typedef typename ToCudaType::MappedType CudaT; + + constexpr bool transa = false; + constexpr bool transb = true; + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR( + helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) return Status::OK(); + + bool is_4bit_done = TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + SafeInt(GetDeviceProp().sharedMemPerBlock), + static_cast(ctx->GetComputeStream()->GetHandle())); + if (!is_4bit_done) { + int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; + IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); + auto* b_data = b_data_ptr.get(); + ORT_RETURN_IF_ERROR(Dequantize4Bits(reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + zero_points_data, + SafeInt(K_padded), + SafeInt(N_), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle()))); +#if 0 + cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); + T* b_data_cpu = new T[K_ * N_]; + cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); + delete[] b_data_cpu; +#endif + + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + if (helper.OutputOffsets().size() == 1) { + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_data), + SafeInt(K_padded), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp())); + } + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu new file mode 100644 index 0000000000..4c3c345076 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -0,0 +1,217 @@ +// Modifications: scaling is moved from masked softmax to the gemm before that. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "matmul_nbits.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a) { + half2 scale_half2 = {scale, scale}; + half zp_adjust = -scale * __short2half_rn(zp); + half2 zp_adjust2 = {zp_adjust, zp_adjust}; + uint4 vec_a = *(reinterpret_cast(a)); + + half2 element01 = __halves2half2(__uint2half_rn(values_quant & 0xF), __uint2half_rn((values_quant >> 4) & 0xF)); + half2 v0 = element01 * scale_half2 + zp_adjust2; + + half2 element23 = __halves2half2(__uint2half_rn((values_quant >> 8) & 0xF), __uint2half_rn((values_quant >> 12) & 0xF)); + half2 v1 = element23 * scale_half2 + zp_adjust2; + + half2 element45 = __halves2half2(__uint2half_rn((values_quant >> 16) & 0xF), __uint2half_rn((values_quant >> 20) & 0xF)); + half2 v2 = element45 * scale_half2 + zp_adjust2; + + half2 element67 = __halves2half2(__uint2half_rn((values_quant >> 24) & 0xF), __uint2half_rn((values_quant >> 28) & 0xF)); + half2 v3 = element67 * scale_half2 + zp_adjust2; + + v0 = v0 * (*(reinterpret_cast(&(vec_a.x)))); + v1 = v1 * (*(reinterpret_cast(&(vec_a.y)))); + v2 = v2 * (*(reinterpret_cast(&(vec_a.z)))) + v0; + v3 = v3 * (*(reinterpret_cast(&(vec_a.w)))) + v1; + v3 = v2 + v3; + return float(v3.x) + float(v3.y); +} + +__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a) { + float4 a_vec_0 = *(reinterpret_cast(a)); + float4 a_vec_1 = *(reinterpret_cast(a + 4)); + + float zp_adjust = -scale * zp; + float v0 = float(values_quant & 0xF) * scale + zp_adjust; + float v1 = float((values_quant >> 4) & 0xF) * scale + zp_adjust; + float v2 = float((values_quant >> 8) & 0xF) * scale + zp_adjust; + float v3 = float((values_quant >> 12) & 0xF) * scale + zp_adjust; + float v4 = float((values_quant >> 16) & 0xF) * scale + zp_adjust; + float v5 = float((values_quant >> 20) & 0xF) * scale + zp_adjust; + float v6 = float((values_quant >> 24) & 0xF) * scale + zp_adjust; + float v7 = float((values_quant >> 28) & 0xF) * scale + zp_adjust; + + v0 = v0 * a_vec_0.x; + v1 = v1 * a_vec_0.y; + v2 = v2 * a_vec_0.z; + v3 = v3 * a_vec_0.w; + v4 = v4 * a_vec_1.x + v0; + v5 = v5 * a_vec_1.y + v1; + v6 = v6 * a_vec_1.z + v2; + v7 = v7 * a_vec_1.w + v3; + return v4 + v5 + v6 + v7; +} + +constexpr int kColsPerThreadBlock = 8; +constexpr int kWarpSize = 32; + +// kernel for 4bits quantized gemv, i.e., computing A(1,K) x B(K, N) +// B(K, N) is quantized blockwise with 4bits and stored as [N, (K + block_size - 1)/block_size, blob] +// The thread block size is (kWarpSize, kColsPerThreadBlock) and grid size is (N/kColsPerThreadBlock, 1) +// Each thread block computes [1, K] x [kColsPerThreadBlock, (K + block_size - 1)/block_size, blob], +// i.e., computing kColsPerThreadBlock per block and a warp reduce (1, K) x (K) +template +__global__ void MatMulFloatInt4Kernel( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int blocks_per_K) { + int n_block_id = blockIdx.x; + int m_id = blockIdx.y; + int lane_id = threadIdx.x; + int warp_id = threadIdx.y; + int n_id = n_block_id * kColsPerThreadBlock + warp_id; + int thread_id = warp_id * kWarpSize + lane_id; + constexpr int k_per_iter = 256; + int k_iter = k / k_per_iter; + + extern __shared__ char shared_buffer[]; + + // load scale to shared buffer + T* b_scale_vec = (T*)shared_buffer; + uint8_t* b_zp_vec = reinterpret_cast(b_scale_vec + kColsPerThreadBlock * blocks_per_K); + int offset = n_block_id * kColsPerThreadBlock * blocks_per_K; + for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K; i += kColsPerThreadBlock * kWarpSize) { + b_scale_vec[i] = scales_data[offset + i]; + } + for (int i = thread_id; i < kColsPerThreadBlock * blocks_per_K / 2; i += kColsPerThreadBlock * kWarpSize) { + b_zp_vec[i] = zero_points != nullptr ? zero_points[offset / 2 + i] : uint8_t(0x88); + } + __syncthreads(); + + a_data += m_id * k; + b_data_quant += n_id * blocks_per_K * (block_size / 2); + + float sum = 0.f; + int k_id = 0; + for (; k_id < (k & 0xffffff00); k_id += k_per_iter) { + uint32_t value = *(reinterpret_cast(b_data_quant + (k_id >> 1) + lane_id * 4)); + int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size; + T scale = b_scale_vec[block_idx]; + uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx / 2] >> 4) : (b_zp_vec[block_idx / 2] & 0x0f); + sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); + } + + // handle reminder + if (k_id + lane_id * 8 < k) { + uint32_t value = *(reinterpret_cast(b_data_quant + k_iter * 128 + lane_id * 4)); + int32_t block_idx = warp_id * blocks_per_K + (k_id + lane_id * 8) / block_size; + T scale = b_scale_vec[block_idx]; + uint8_t zp = (block_idx & 0x01) ? (b_zp_vec[block_idx / 2] >> 4) : (b_zp_vec[block_idx / 2] & 0x0f); + sum += AccumulateEightElements(value, scale, zp, a_data + k_id + (lane_id << 3)); + } + + // warp reduction + for (int i = 16; i > 0; i = i / 2) { + sum += __shfl_down_sync(0xffffffff, sum, i); + } + + if (lane_id == 0) { + output[m_id * n + n_id] = sum; + } +} + +template +bool TryMatMul4Bits( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream) { + if (n % kColsPerThreadBlock != 0 || k % 8 != 0 || m > 1) { + return false; + } + dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); + dim3 threads(kWarpSize, kColsPerThreadBlock); + int blocks_per_K = (k + block_size - 1) / block_size; + int blocks_per_thread_block = blocks_per_K * kColsPerThreadBlock; + int shared_mem_size = sizeof(T) * blocks_per_thread_block + blocks_per_thread_block / 2; + if (shared_mem_size > shared_mem_per_block) { + return false; + } + + if (16 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else if (32 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else if (64 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else if (128 == block_size) { + MatMulFloatInt4Kernel<<>>( + output, a_data, b_data_quant, scales_data, zero_points, m, n, k, blocks_per_K); + } else { + ORT_THROW("block size ", block_size, " is not supported"); + } + + return true; +} + +template bool TryMatMul4Bits( + float* output, + const float* a_data, + const uint8_t* b_data_quant, + const float* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream); + +template bool TryMatMul4Bits( + half* output, + const half* a_data, + const uint8_t* b_data_quant, + const half* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh new file mode 100644 index 0000000000..9ccbe4c4d9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +bool TryMatMul4Bits( + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + int shared_mem_per_block, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 07a8896210..67b52b466f 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -1315,6 +1315,67 @@ template void BufferExpansionKernelLauncher(const int32_t* input, int chunk_size, cudaStream_t stream); +// Support head_size up to 128 +constexpr unsigned int kTileSize = 32; +constexpr unsigned int kSeqTileSize = 16; + +__global__ void ReorderPastStatesKernel(float4* out_buffer, + const float4* in_buffer, + int batch_size, + int num_heads, + int max_length, + int chunked_head_size) { + __shared__ float4 tile[kSeqTileSize][kTileSize + 1]; + + const int b = blockIdx.z; + const int n = blockIdx.y; + const int s_base = blockIdx.x * kSeqTileSize; + const int s = s_base + threadIdx.y; + const int base_offset = (b * num_heads + n) * max_length * chunked_head_size; + + if (s < max_length) { + const int in_offset = base_offset + s * chunked_head_size + threadIdx.x; + tile[threadIdx.y][threadIdx.x] = in_buffer[in_offset]; + } + + __syncthreads(); + + const int tidx = threadIdx.x + threadIdx.y * chunked_head_size; + const int tidx_x = tidx % kSeqTileSize; + const int tidx_y = tidx / kSeqTileSize; + + const int s2 = s_base + tidx_x; + + if (s2 < max_length) { + const int out_offset = base_offset + tidx_y * max_length + s2; + out_buffer[out_offset] = tile[tidx_x][tidx_y]; + } +} + +void ReorderPastStatesKernelLauncher(void* out_buffer, + const void* in_buffer, + int batch_size, + int num_heads, + int max_length, + int head_size, + int chunk_size, + cudaStream_t stream) { + //[B, N, max_length, H2(head_size/chunk_size), equv_chunk_size] -> [B, N, H2(head_size/chunk_size), max_length, equv_chunk_size] + const int chunked_head_size = head_size / chunk_size; + const dim3 block(chunked_head_size, kSeqTileSize); + const dim3 grid((max_length + kSeqTileSize - 1) / kSeqTileSize, num_heads, batch_size); + if (chunk_size == 4 || chunk_size == 8) { + ReorderPastStatesKernel<<>>(reinterpret_cast(out_buffer), + reinterpret_cast(in_buffer), + batch_size, + num_heads, + max_length, + chunked_head_size); + } else { + ORT_THROW("ReorderPastStatesKernelLauncher only support float or half"); + } +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 8c52f6fd52..2c3662fb18 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -213,6 +213,14 @@ void BufferExpansionKernelLauncher(const T* input, int chunk_size, cudaStream_t stream); +void ReorderPastStatesKernelLauncher(void* out_buffer, + const void* in_buffer, + int batch_size, + int num_heads, + int max_length, + int head_size, + int chunk_size, + cudaStream_t stream); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e4de33499c..121cd05956 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -56,19 +56,23 @@ namespace GenerationCudaDeviceHelper { // It might be better to forcefully require the same type since cast node generates // extra overhead. Status ReorderPastState( - const void* cuda_device_prop, + const void*, Tensor& past_state, Tensor& past_state_staging, Stream* stream) { ORT_ENFORCE(stream); cudaStream_t cuda_stream = reinterpret_cast(stream->GetHandle()); - cublasHandle_t cublas_handle = static_cast(stream)->cublas_handle_; const auto& past_state_shape = past_state.Shape(); const auto& past_state_dims = past_state_shape.GetDims(); const bool packed_past = past_state_dims.size() == 5; + size_t batch_size = packed_past ? past_state_dims[1] : past_state_dims[0]; + size_t num_heads = packed_past ? past_state_dims[2] : past_state_dims[1]; + size_t max_length = packed_past ? past_state_dims[3] : past_state_dims[2]; + size_t head_size = packed_past ? past_state_dims[4] : past_state_dims[3]; + // Copy the 'K' values into the temp staging buffer size_t past_state_size = packed_past ? past_state.SizeInBytes() / 2 : past_state.SizeInBytes(); void* past_state_staging_buffer = past_state_staging.MutableDataRaw(); @@ -79,27 +83,16 @@ Status ReorderPastState( // [B, N, head_size / x, max_length, x], where x = 16 / sizeof(T) int64_t chunk_size = static_cast(16 / past_state.DataType()->Size()); - std::vector permutation_vector = {0, 1, 3, 2, 4}; - gsl::span permutation(permutation_vector.data(), 5); - - // "Fake" the shapes of the input and output tensors of the Transpose operation to suit our need - size_t offset = packed_past ? 1 : 0; - TensorShape transpose_input_shape_override = {past_state_shape[offset], - past_state_shape[offset + 1], - past_state_shape[offset + 2], - past_state_shape[offset + 3] / chunk_size, - chunk_size}; - - TensorShape transpose_output_shape_override = {past_state_shape[offset], past_state_shape[offset + 1], - past_state_shape[offset + 3] / chunk_size, past_state_shape[offset + 2], - chunk_size}; - - // TODO(hasesh): Explore perf tuning for this Transpose operation - return onnxruntime::cuda::Transpose::DoTranspose(*static_cast(cuda_device_prop), cuda_stream, - cublas_handle, permutation, - past_state_staging, past_state, - &transpose_input_shape_override, - &transpose_output_shape_override); + cuda::ReorderPastStatesKernelLauncher(past_state.MutableDataRaw(), + past_state_staging_buffer, + static_cast(batch_size), + static_cast(num_heads), + static_cast(max_length), + static_cast(head_size), + static_cast(chunk_size), + cuda_stream); + + return Status::OK(); } Status InitCacheIndir(Tensor& cache_indir, Stream* stream) { diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index 9a150c9e6c..b0ed3ff822 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -30,6 +30,7 @@ limitations under the License. #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/rocm/bert/attention_impl.h" #include "contrib_ops/rocm/bert/attention_softmax.h" +#include "contrib_ops/rocm/bert/decoder_attention_impl.h" using namespace onnxruntime::rocm; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 19b2bc34ef..3164e8c211 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -28,34 +28,6 @@ size_t GetAttentionWorkspaceSize( int sequence_length, int past_sequence_length); -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, // Device Properties - RocmTuningContext* tuning_ctx, // context for tuning - Stream* stream, // ORT Stream - rocblas_handle& rocblas, // Rocblas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden layer size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - Status LaunchTransCtx(hipStream_t stream, const int sequence_length, const int batch_size, const int head_size, const int num_heads, const int max_threads_per_block, const bool reversed_bs, const float* input, float* output); diff --git a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h new file mode 100644 index 0000000000..d71c6d8440 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/providers/rocm/shared_inc/rocm_utils.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +Status LaunchDecoderAttentionKernel( + const hipDeviceProp_t& prop, // Device Properties + RocmTuningContext* tuning_ctx, // context for tuning + Stream* stream, // ORT Stream + rocblas_handle& rocblas, // Rocblas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden layer size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const float mask_filter_value, // Mask filter value + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor +); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc index c665da89af..e82e15a304 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc @@ -72,6 +72,12 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); } +Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + return Status::OK(); +} + Status GroupNorm::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* gamma = context->Input(1); diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 300db24a98..0bf27fdf5e 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -1715,31 +1715,39 @@ class PlannerImpl { void PartitionIntoStreams(const logging::Logger& /*logger*/, const ExecutionProviders& /*execution_providers*/, const PathString& /*partition_config_file*/) { - stream_nodes_.push_back({}); - node_stream_map_.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); - for (auto node_index : graph_viewer_.GetNodesInTopologicalOrder()) { - stream_nodes_[0].push_back(node_index); - node_stream_map_[node_index] = 0; + if (graph_viewer_.NumberOfNodes() > 0) { + stream_nodes_.push_back({}); + node_stream_map_.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); + for (auto node_index : graph_viewer_.GetNodesInTopologicalOrder()) { + stream_nodes_[0].push_back(node_index); + node_stream_map_[node_index] = 0; + } + num_logic_streams_ = 1; } - num_logic_streams_ = 1; } Status BuildExecutionPlan(const ExecutionProviders& execution_providers) { // 1. create logic stream instance auto& execution_plan = plan_.execution_plan; - ORT_ENFORCE(num_logic_streams_ == 1 && !stream_nodes_[0].empty()); - execution_plan.reserve(1); - auto first_node_index = stream_nodes_[0][0]; - auto* node = graph_viewer_.GetNode(first_node_index); - onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType(); - const IExecutionProvider* ep = execution_providers.Get(exec_provider_name); - ORT_ENFORCE(ep); - auto node_device_mem_location = ep->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault); - execution_plan.emplace_back(std::make_unique(node_device_mem_location)); - // 2. add steps to the execution plan - for (auto node_index : stream_nodes_[0]) { - execution_plan[0]->steps_.emplace_back(std::make_unique(node_index)); + + if (graph_viewer_.NumberOfNodes() > 0) { + ORT_ENFORCE(num_logic_streams_ == 1 && !stream_nodes_[0].empty()); + execution_plan.reserve(1); + auto first_node_index = stream_nodes_[0][0]; + auto* node = graph_viewer_.GetNode(first_node_index); + onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType(); + const IExecutionProvider* ep = execution_providers.Get(exec_provider_name); + ORT_ENFORCE(ep); + auto node_device_mem_location = ep->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault); + execution_plan.emplace_back(std::make_unique(node_device_mem_location)); + // 2. add steps to the execution plan + for (auto node_index : stream_nodes_[0]) { + execution_plan[0]->steps_.emplace_back(std::make_unique(node_index)); + } + } else { + // graph with no nodes. e.g. subgraph of If might return the input as-is or a constant value from an initializer } + return Status::OK(); } diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index fc2f14263f..df3a7afebc 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -254,10 +254,11 @@ common::Status SaveInitializedTensors( auto initialized_tensors_to_allocate = id_to_initialized_tensor; for (int ort_value_index : initializer_allocation_order) { const auto entry = initialized_tensors_to_allocate.find(ort_value_index); + ORT_ENFORCE(entry != initialized_tensors_to_allocate.end(), + "OrtValue index: ", ort_value_index, " from initializer_allocation_order not found among initialized tensors"); if (!(utils::HasExternalData(*entry->second) && exec_plan.GetLocation(ort_value_index).Type() == OrtDevice::CPU)) { // can not trace string tensor - ORT_ENFORCE(entry != initialized_tensors_to_allocate.end() && - entry->second->data_type() != ONNX_NAMESPACE::TensorProto_DataType_STRING); + ORT_ENFORCE(entry->second->data_type() != ONNX_NAMESPACE::TensorProto_DataType_STRING, "Can not trace string tensor"); ORT_RETURN_IF_ERROR(planner.Trace(entry->first, entry->second)); } initialized_tensors_to_allocate.erase(entry); diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 5a42f5d34b..08ed811d9a 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1492,7 +1492,7 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer, if (initializer.data_location() == TensorProto_DataLocation_EXTERNAL) { ORT_RETURN_IF_ERROR(ReadExternalDataForTensor( initializer, - model_path.IsEmpty() ? nullptr : model_path.ParentPath().ToPathString().c_str(), + (model_path.IsEmpty() || model_path.ParentPath().IsEmpty()) ? nullptr : model_path.ParentPath().ToPathString().c_str(), unpacked_tensor)); return Status::OK(); } diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 56f41154b7..ea6a629f87 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -223,7 +223,7 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType); -#ifdef ENABLE_TRAINING_CORE +#ifdef ENABLE_TRAINING common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context); #endif diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index e5956a575d..893776e778 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -171,10 +171,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2] * query_dims[4]; updateOutputShape(ctx, 0, output_shape); - return; - } - - if (hasInputShape(ctx, 2)) { + } else if (hasInputShape(ctx, 2)) { auto& value_shape = getInputShape(ctx, 2); auto& value_dims = value_shape.dim(); if (value_dims.size() != 3 && value_dims.size() != 4) { @@ -192,10 +189,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c ? (dmmha_packing ? value_dims[2] / 3 : value_dims[2]) : value_dims[1] * value_dims[3]; updateOutputShape(ctx, 0, output_shape); - return; - } - - if (hasInputShape(ctx, 1)) { + } else if (hasInputShape(ctx, 1)) { auto& key_shape = getInputShape(ctx, 1); if (key_shape.dim().size() == 5) { // packed KV ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx); @@ -217,7 +211,7 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c propagateElemTypeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); } else { if (sequence_length > 0 && past_dims[2].has_dim_value()) { - int64_t total_sequence_length = sequence_length + past_shape.dim(3).dim_value(); + int64_t total_sequence_length = sequence_length + past_dims[2].dim_value(); ONNX_NAMESPACE::TensorShapeProto present_shape; for (auto& dim : past_dims) { @@ -233,6 +227,59 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c } } +void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) { + // Output 0 has shape (batch_size, sequence_length, hidden_size) + + // Q, K and V: + // Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + // Input 1 (key) has shape (batch_size, kv_sequence_length, kv_hidden_size) + // Input 2 (value) has shape (batch_size, kv_sequence_length, kv_hidden_size) + + // Type inference + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + + // Shape inference + if (hasInputShape(ctx, 0)) { + auto& query_shape = getInputShape(ctx, 0); + auto& query_dims = query_shape.dim(); + + if (query_dims.size() != 3) { + fail_shape_inference("Inputs 0 (query) shall be 3 dimensions"); + } + + if (hasInputShape(ctx, 2)) { + auto& value_shape = getInputShape(ctx, 2); + auto& value_dims = value_shape.dim(); + if (value_dims.size() != 3) { + fail_shape_inference("Inputs 2 (value) shall be 3 dimensions"); + } + + ONNX_NAMESPACE::TensorShapeProto output_shape; + *output_shape.add_dim() = query_dims[0]; + *output_shape.add_dim() = query_dims[1]; + *output_shape.add_dim() = query_dims[2]; + updateOutputShape(ctx, 0, output_shape); + return; + } else { + fail_shape_inference("Missing input 2 (value)"); + } + } + + if (ctx.getNumOutputs() > 1) { // has present output + if (hasInputShape(ctx, past_key_index)) { + auto& past_shape = getInputShape(ctx, past_key_index); + auto& past_dims = past_shape.dim(); + if (past_dims.size() != 4) { + fail_shape_inference("The past_key input shall be 4 dimensions"); + } + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1); + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); + ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1); + ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); + } + } +} + constexpr const char* Attention_ver1_doc = R"DOC( Multi-Head Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT). @@ -823,7 +870,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T") .Output(1, "present_key", - "past state for key with shape (batch_size, num_heads, total_sequence_length, head_size). " + "present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). " "If past_present_share_buffer is set, " "its shape is (batch_size, num_heads, max_sequence_length, head_size), " "while effective_seq_length = (past_sequence_length + kv_sequence_length).", @@ -831,7 +878,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Output(2, "present_value", - "past state for value with shape (batch_size, num_heads, total_sequence_length, head_size). " + "present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). " "If past_present_share_buffer is set, " "its shape is (batch_size, num_heads, max_sequence_length, head_size), " "while effective_seq_length = (past_sequence_length + kv_sequence_length).", @@ -889,7 +936,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(4, "key_padding_mask", - "Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)", + "Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), " + "or (batch_size, sequence_length, total_sequence_length)", "M", OpSchema::Optional) .Input(5, @@ -930,6 +978,80 @@ ONNX_MS_OPERATOR_SET_SCHEMA( MultiHeadAttentionTypeAndShapeInference(ctx, 6); })); +constexpr const char* GroupQueryAttention_ver1_doc = R"DOC( +Group Query Self/Cross Attention. + +Supports different number of heads for q and kv. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + GroupQueryAttention, 1, + OpSchema() + .SetDoc(GroupQueryAttention_ver1_doc) + .Attr("num_heads", "Number of attention heads for q", AttributeProto::INT) + .Attr("kv_num_heads", "Number of attention heads for k and v", AttributeProto::INT) + .Attr("scale", + "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", + AttributeProto::FLOAT, + OPTIONAL_VALUE) + // .Attr("left_padding_last_token", + // "Copy last token to last index of buffer. Default is 0; 1 when true.", + // AttributeProto::INT, + // OPTIONAL_VALUE) + .Input(0, + "query", + "Query with shape (batch_size, sequence_length, hidden_size)", + "T") + .Input(1, + "key", + "Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ", + "T") + .Input(2, + "value", + "Value with shape (batch_size, kv_sequence_length, kv_hidden_size)", + "T") + .Input(3, + "past_key", + "past state key with support for format BNSH. When past_key uses same tensor as present_key" + "(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.", + "T", + OpSchema::Optional) + .Input(4, + "past_value", + "past state value with support for format BNSH. When past_value uses same tensor as present_value" + "(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.", + "T", + OpSchema::Optional) + .Input(5, + "seqlens_k", + "1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.", + "M") + .Input(6, + "total_sequence_length", + "Scalar tensor of total sequence length (past + new).", + "M") + .Output(0, + "output", + "3D output tensor with shape (batch_size, sequence_length, hidden_size)", + "T") + .Output(1, + "present_key", + "present state key with support for format BNSH. When past_key uses same tensor as present_key" + "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" + "kv_sequence_length.", + "T") + .Output(2, + "present_value", + "present state value with support for format BNSH. When past_value uses same tensor as present_value" + "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" + "kv_sequence_length.", + "T") + .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.") + .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + GroupQueryAttentionTypeAndShapeInference(ctx, 3); + })); + constexpr const char* Longformer_Attention_doc = R"DOC( Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token attends to its W previous tokens and W succeeding tokens with W being the window length. A selected few tokens @@ -994,6 +1116,49 @@ ONNX_MS_OPERATOR_SET_SCHEMA( DecoderAttentionTypeAndShapeInference(ctx); })); +constexpr const char* RotaryEmbedding_ver1_doc = R"DOC( +RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices +that are multiplied to query and key before the inner product of query and key is taken. +)DOC"; +ONNX_MS_OPERATOR_SET_SCHEMA( + RotaryEmbedding, 1, + OpSchema() + .SetDoc(RotaryEmbedding_ver1_doc) + .Attr("scale", + "Custom scale will be used if specified. Default value is 1.0", + AttributeProto::FLOAT, + OPTIONAL_VALUE) + .Attr("interleaved", + "Rotate using interleaved pattern. Default value is 0 (False).", + AttributeProto::INT, + OPTIONAL_VALUE) + .Input(0, + "input", + "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "T") + .Input(1, + "position_ids", + "1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)", + "M") + .Input(2, + "cos_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T") + .Input(3, + "sin_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T") + .Output(0, + "output", + "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + propagateShapeFromInputToOutput(ctx, 0, 0); + })); + constexpr const char* EmbedLayerNormalization_ver1_doc = R"DOC( EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing. The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index a79203a94a..1db7638677 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2384,6 +2384,35 @@ ONNX_MS_OPERATOR_SET_SCHEMA(CropAndResize, 1, a fixed size = [crop_height, crop_width]. The result is a 4-D tensor [num_boxes, crop_height, crop_width, depth]. The resizing is corner aligned.)DOC")); +static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, + int64_t K, + int64_t N) { + int input_a_idx = 0; + if (!hasInputShape(ctx, input_a_idx)) { + return; + } + + const auto& a_shape = ctx.getInputType(input_a_idx)->tensor_type().shape(); + if (a_shape.dim_size() == 0) { + fail_shape_inference("Input tensors of wrong rank (0)."); + } + + // TODO: check B shape + + const auto& dim_last = a_shape.dim(a_shape.dim_size() - 1); + if (dim_last.has_dim_value() && dim_last.dim_value() != K) { + fail_shape_inference("Incompatible dimensions for matrix multiplication"); + } + + ONNX_NAMESPACE::TensorShapeProto resultShape; + for (int i = 0; i < a_shape.dim_size() - 1; ++i) { + *resultShape.add_dim() = a_shape.dim(i); + } + resultShape.add_dim()->set_dim_value(N); + + *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape; +} + void RegisterContribSchemas() { ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(AttnLSTM, RegisterAttnLSTMContribOpSchema); ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema); @@ -2895,6 +2924,90 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t } }); + static const char* MatMulNBits_ver1_doc = R"DOC( +MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7).It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's scale and zero point are specified by input scales and zero_points. + +Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: +- n_blocks_per_col = (K + block_size - 1) / block_size +- blob_size = block_size / 8 * bits + + For a block blob. It is stored in format: + struct Blob { + uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization + uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization + uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization + } + +Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] +Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: + - [(N * n_blocks_per_col + 1) / 2] if bits <=4 + - [N * n_blocks_per_col] if bits > 4 + +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulNBits_ver1_doc) + .Attr("K", "size of each input feature", AttributeProto::INT) + .Attr("N", "size of each output feature", AttributeProto::INT) + .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT) + .Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Input(0, "A", "The input tensor, not quantized", "T1") + .Input(1, "B", "1-dimensional data blob", "T2") + .Input(2, "scales", "quantization scale", "T1") + .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional) + .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + // Shape inference + int64_t in_features = getAttribute(ctx, "K", -1); + int64_t out_features = getAttribute(ctx, "N", -1); + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); + }); + + static const char* MatMulBnb4_ver1_doc = R"DOC( +MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's quantization constants or scales are specified by input 'absmax'. + +Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. +Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulBnb4) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulBnb4_ver1_doc) + .Attr("K", "size of each input feature", AttributeProto::INT) + .Attr("N", "size of each output feature", AttributeProto::INT) + .Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT) + .Input(0, "A", "The input tensor, not quantized", "T1") + .Input(1, "B", "1-dimensional quantized data for weight", "T2") + .Input(2, "absmax", "quantization constants", "T1") + .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + // Shape inference + int64_t in_features = getAttribute(ctx, "K", -1); + int64_t out_features = getAttribute(ctx, "N", -1); + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); + }); + #ifdef ENABLE_ATEN ONNX_CONTRIB_OPERATOR_SCHEMA(ATen) .SetDomain(kPytorchAtenDomain) @@ -2913,6 +3026,24 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t "Allow inputs and outputs to be any kind of tensor."); #endif + ONNX_CONTRIB_OPERATOR_SCHEMA(QuadricCustomOp) + .SetDomain(kQuadricDomain) + .SinceVersion(1) + .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) + .SetDoc("QuadricCustomOp") + .Input(0, "inputs", "QuadricCustomOp inputs.", "T", OpSchema::Variadic, + /*is_homogeneous*/ false, + /*min_arity*/ 1) + .Output(0, "outputs", "QuadricCustomOp outputs.", "T", OpSchema::Variadic, + /*is_homogeneous*/ false, + /*min_arity*/ 1) + .AllowUncheckedAttributes() + .Attr("ccl_func_name", "Name of CCL function.", AttributeProto::STRING) + .Attr("sub_graph", "Replaced sub-graph.", AttributeProto::GRAPH) + .Attr("element_wise", "True (1) if only element-wise ops, False (0) otherwise", AttributeProto::INT, true) + .TypeConstraint("T", OpSchema::all_tensor_types_with_bfloat(), + "Allow inputs and outputs to be any kind of tensor."); + #ifdef ENABLE_TRAINING_OPS // Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or // 2). this is needed by inference for other purpose. diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index c2f5edaa61..f81c3b8e01 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -42,7 +42,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The number of groups of channels. It should be a divisor of the number of channels C", AttributeProto::INT) .Attr("activation", - "Activation after group normalization: 0 for None, 1 for Swish", + "Activation after group normalization: 0 for None, 1 for SiLU", AttributeProto::INT) .Attr("channels_last", "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", @@ -68,6 +68,85 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +constexpr const char* SkipGroupNorm_ver1_doc = R"DOC( +This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + +This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + +The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. +The num_channels must be divisible by num_groups. +The mean and standard-deviation of s are calculated separately over the each group. +The weight and bias are per-channel affine transform parameter vectors of size num_channels. + +The activation attribute can be used to enable activation after group normalization. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + SkipGroupNorm, 1, + OpSchema() + .SetDoc(SkipGroupNorm_ver1_doc) + .Attr("epsilon", "The epsilon value to use to avoid division by zero", + AttributeProto::FLOAT, static_cast(1e-5)) + .Attr("groups", + "The number of groups of channels. It should be a divisor of the number of channels C", + AttributeProto::INT) + .Attr("activation", + "Activation after group normalization: 0 for None, 1 for SiLU", + AttributeProto::INT) + .Attr("channels_last", + "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", + AttributeProto::INT, + static_cast(1)) + .Input(0, + "X", + "Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 " + " or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels," + " and H and W are the height and width of the data", + "T") + .Input(1, + "gamma", + "1D gamma tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(2, + "beta", + "1D beta tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(3, + "skip", + "4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)", + "T") + .Input(4, + "bias", + "1D bias tensor. Dimensions are (C), where C is number of channels", + "T", + OpSchema::Optional) + .Output(0, + "Y", + "The output tensor of the same shape as X", + "T") + .Output(1, + "S", + "The element-wise sum of input x, skip and bias tensors. It has the same shape as X", + "T", + OpSchema::Optional) + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X, skip, bias and output Y, S types to float tensors.") + .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateElemTypeFromInputToOutput(ctx, 0, 1); + } + + if (hasInputShape(ctx, 0)) { + propagateShapeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateShapeFromInputToOutput(ctx, 0, 1); + } + } + })); + constexpr const char* BiasSplitGelu_ver1_doc = R"DOC( A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left tensor multiplies the Gelu activation result of right tensor. diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 3c31997286..7427469df0 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -27,6 +27,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QGemm); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAdd); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConcat); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConvTranspose); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearWhere); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearLeakyRelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearMul); @@ -83,6 +84,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4); #endif class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad); @@ -93,8 +95,10 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatedRelativePositionBia class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul); @@ -113,6 +117,7 @@ class OpSet_Microsoft_ver1 { static void ForEachSchema(std::function fn) { fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -182,6 +187,7 @@ class OpSet_Microsoft_ver1 { #endif fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); @@ -194,8 +200,10 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/graph/function_utils.cc b/onnxruntime/core/graph/function_utils.cc index 4b79001944..aa0727e375 100644 --- a/onnxruntime/core/graph/function_utils.cc +++ b/onnxruntime/core/graph/function_utils.cc @@ -344,10 +344,15 @@ std::unique_ptr CreateSchema(const std::string& functi std::unordered_map map_copy(model_local_functions.begin(), model_local_functions.end()); std::unordered_map empty_map; - ONNX_NAMESPACE::shape_inference::SymbolTableImpl symbolTable; + + // https://github.com/microsoft/onnxruntime/issues/17061 + // We are passing a nullptr for the symbol table, because symbol table must be global + // for all the shape inferencing to work correctly. Otherwise, unrelated shapes get + // the same symbolic shapes and are marked for memory re-use. This is a Temp fix. + constexpr ONNX_NAMESPACE::shape_inference::SymbolTableImpl* symbolTable = nullptr; ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(*onnx_func_proto, func_domain_to_version, schema_registry, ctx, options, map_copy, - &symbolTable, &empty_map); + symbolTable, &empty_map); }); op_schema->Finalize(); diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 86b7450a7c..32cc69d0b8 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -451,12 +451,16 @@ Return Value: #if defined(_WIN32) HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); -#elif !defined(__APPLE__) // The next few lines result in an EXC_BAD_INSTRUCTION runtime error on a M1 Mac so we - // disable it there. - uint64_t isar0_el1; - asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); - HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; #else + // Use the cpuinfo value which is read from sysctl and has some additional special cases. + // https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379 + // Do NOT use ID_AA64ISAR0_EL1. It causes illegal instruction errors on Mac M1 and ARMv8-A chips + // as well as failing on other ARM chips as it is an EL1 level register that requires extra + // privileges to read. + // + // uint64_t isar0_el1; + // asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); + // HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); #endif diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc index a3ac431205..dd38ee9b07 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc @@ -462,6 +462,27 @@ bool LayerNormalizationGatherActor::PreCheck(const Graph& /* graph */, return true; } +bool LayerNormalizationGatherActor::PostProcess(Graph& /*graph*/, Node& current_node, + const SliceInfo& info_without_node, + const logging::Logger& /*logger*/, + const std::unordered_map& /*propagate_input_indices*/, + const std::unordered_map>& + /*all_input_cmp_rets*/, + const std::unordered_map& /*new_gather_infos*/) { + // Update LayerNormalization's axis attribute if it is scalar slice. + if (info_without_node.is_scalar_slice) { + auto axis = static_cast(current_node.GetAttributes().at("axis").i()); + auto original_ln_input_rank = info_without_node.input_rank; + axis = axis < 0 ? axis + original_ln_input_rank : axis; + auto new_axis = axis - 1; + + auto& attributes = current_node.GetMutableAttributes(); + attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast(new_axis)); + } + + return true; +} + bool SoftmaxGatherActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, const logging::Logger& logger, std::unordered_map& propagate_input_indices, @@ -479,6 +500,28 @@ bool SoftmaxGatherActor::PreCheck(const Graph& graph, const Node& current_node, propagate_input_indices, all_input_cmp_rets, shape_update_func); } +bool SoftmaxGatherActor::PostProcess(Graph& graph, Node& current_node, const SliceInfo& info_without_node, + const logging::Logger& logger, + const std::unordered_map& propagate_input_indices, + const std::unordered_map>& all_input_cmp_rets, + const std::unordered_map& new_gather_infos) { + SimplePointwiseGatherActor::PostProcess(graph, current_node, info_without_node, logger, + propagate_input_indices, all_input_cmp_rets, new_gather_infos); + + // Update Softmax's axis attribute if it is scalar slice. + if (info_without_node.is_scalar_slice) { + auto axis = static_cast(current_node.GetAttributes().at("axis").i()); + auto original_ln_input_rank = info_without_node.input_rank; + axis = axis < 0 ? axis + original_ln_input_rank : axis; + auto new_axis = axis - 1; + + auto& attributes = current_node.GetMutableAttributes(); + attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast(new_axis)); + } + + return true; +} + bool ReshapeGatherActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, const logging::Logger& logger, std::unordered_map& propagate_input_indices, @@ -566,6 +609,11 @@ bool ReshapeGatherActor::PreCheck(const Graph& graph, const Node& current_node, return true; } + LOG_DEBUG_INFO(logger, "Skip handle the Reshape, new_shape_const_values[info.non_negative_axis]:" + + std::to_string(new_shape_const_values[info.non_negative_axis]) + + ", info.output_dim_on_axis.has_dim_value(): " + + std::to_string(info.output_dim_on_axis.has_dim_value()) + "."); + return false; } @@ -604,11 +652,12 @@ bool ReshapeGatherActor::PostProcess( return true; } - // If it selected shape is a dim value, we can update the shape tensor directory. + // If the selected shape is a dim value, we can update the shape tensor directory. if (info_without_node.output_dim_on_axis.has_dim_value()) { new_shape_const_values[slice_axis] = info_without_node.output_dim_on_axis.dim_value(); auto new_shape_arg = - CreateInitializerFromVector(graph, {static_cast(new_shape_const_values.size())}, new_shape_const_values, + CreateInitializerFromVector(graph, {static_cast(new_shape_const_values.size())}, + new_shape_const_values, graph.GenerateNodeArgName(current_node.MutableInputDefs()[1]->Name())); graph_utils::ReplaceNodeInput(current_node, 1, *new_shape_arg); return true; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h index f6715e4bb1..0c21be1397 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h @@ -189,7 +189,7 @@ class LayerNormalizationGatherActor : public UpStreamGatherOperatorActorBase { const logging::Logger& /* logger */, const std::unordered_map& /* propagate_input_indices */, const std::unordered_map>& /* all_input_cmp_rets */, - const std::unordered_map& /* new_gather_infos */) override { return true; } + const std::unordered_map& /* new_gather_infos */) override; }; class SoftmaxGatherActor : public SimplePointwiseGatherActor { @@ -202,6 +202,12 @@ class SoftmaxGatherActor : public SimplePointwiseGatherActor { std::unordered_map& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& /* current_node */, const SliceInfo& /* info_without_node */, + const logging::Logger& /* logger */, + const std::unordered_map& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_gather_infos */) override; }; class ReshapeGatherActor : public UpStreamGatherOperatorActorBase { diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 18010960e1..3d54bbd474 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -958,6 +958,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Shape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kQuadricDomain, 1, QuadricCustomOp); + // !!PLEASE READ BELOW!! Following that, add new entries above this comment /* *** IMPORTANT! *** @@ -2383,6 +2385,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index aba9a798cf..b9f3050e59 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -64,6 +64,18 @@ enum MissingTrack : uint8_t { kFalse = 0 }; +template +struct TreeNodeElement; + +template +union PtrOrWeight { + TreeNodeElement* ptr; + struct WeightData { + int32_t weight; + int32_t n_weights; + } weight_data; +}; + template struct TreeNodeElement { int feature_id; @@ -71,24 +83,19 @@ struct TreeNodeElement { // Stores the node threshold or the weights if the tree has one target. T value_or_unique_weight; - // onnx specification says hitrates is used to store information about the node, + // The onnx specification says hitrates is used to store information about the node, // but this information is not used for inference. // T hitrates; - // True node, false node are obtained by computing `this + truenode_inc_or_first_weight`, - // `this + falsenode_inc_or_n_weights` if the node is not a leaf. - // In case of a leaf, these attributes are used to indicate the position of the weight - // in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, - // the weight is also stored in `value_or_unique_weight`. - // This implementation assumes a tree has less than 2^31 nodes, - // and the total number of leave in the set of trees is below 2^31. - // A node cannot point to itself. - int32_t truenode_inc_or_first_weight; - // In case of a leaf, the following attribute indicates the number of weights - // in array `TreeEnsembleCommon::weights_`. If not a leaf, it indicates - // `this + falsenode_inc_or_n_weights` is the false node. - // A node cannot point to itself. - int32_t falsenode_inc_or_n_weights; + // PtrOrWeight acts as a tagged union, with the "tag" being whether the node is a leaf or not (see `is_not_leaf`). + + // If it is not a leaf, it is a pointer to the true child node when traversing the decision tree. The false branch is + // always 1 position away from the TreeNodeElement in practice in `TreeEnsembleCommon::nodes_` so it is not stored. + + // If it is a leaf, it contains `weight` and `n_weights` attributes which are used to indicate the position of the + // weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also + // stored in `value_or_unique_weight`. + PtrOrWeight truenode_or_weight; uint8_t flags; inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); } @@ -189,8 +196,8 @@ class TreeAggregatorSum : public TreeAggregator>& predictions, const TreeNodeElement& root, gsl::span> weights) const { - auto it = weights.begin() + root.truenode_inc_or_first_weight; - for (int32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) { + auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; + for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { ORT_ENFORCE(it->i < (int64_t)predictions.size()); predictions[onnxruntime::narrow(it->i)].score += it->value; predictions[onnxruntime::narrow(it->i)].has_score = 1; @@ -292,8 +299,8 @@ class TreeAggregatorMin : public TreeAggregator>& predictions, const TreeNodeElement& root, gsl::span> weights) const { - auto it = weights.begin() + root.truenode_inc_or_first_weight; - for (int32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) { + auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; + for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { predictions[onnxruntime::narrow(it->i)].score = (!predictions[onnxruntime::narrow(it->i)].has_score || it->value < predictions[onnxruntime::narrow(it->i)].score) ? it->value @@ -349,8 +356,8 @@ class TreeAggregatorMax : public TreeAggregator>& predictions, const TreeNodeElement& root, gsl::span> weights) const { - auto it = weights.begin() + root.truenode_inc_or_first_weight; - for (int32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) { + auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; + for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { predictions[onnxruntime::narrow(it->i)].score = (!predictions[onnxruntime::narrow(it->i)].has_score || it->value > predictions[onnxruntime::narrow(it->i)].score) ? it->value diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 161bb2b082..8f847fe66a 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -85,6 +85,13 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { template void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const; + + private: + size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids, const std::vector& nodes_featureids, + const std::vector& nodes_values_as_tensor, const std::vector& node_values, + const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, + int64_t tree_id, const InlinedVector& node_tree_ids); }; template @@ -186,7 +193,7 @@ Status TreeEnsembleCommon::Init( max_tree_depth_ = 1000; ORT_ENFORCE(nodes_modes.size() < std::numeric_limits::max()); - // additional members + // Additional members size_t limit; uint32_t i; InlinedVector cmodes; @@ -195,18 +202,14 @@ Status TreeEnsembleCommon::Init( int fpos = -1; for (i = 0, limit = nodes_modes.size(); i < limit; ++i) { cmodes.push_back(MakeTreeNodeMode(nodes_modes[i])); - if (cmodes[i] == NODE_MODE::LEAF) - continue; + if (cmodes[i] == NODE_MODE::LEAF) continue; if (fpos == -1) { fpos = static_cast(i); continue; } - if (cmodes[i] != cmodes[fpos]) - same_mode_ = false; + if (cmodes[i] != cmodes[fpos]) same_mode_ = false; } - // filling nodes - n_nodes_ = nodes_treeids.size(); limit = static_cast(n_nodes_); InlinedVector node_tree_ids; @@ -214,156 +217,185 @@ Status TreeEnsembleCommon::Init( nodes_.clear(); nodes_.reserve(limit); roots_.clear(); - std::unordered_map idi; - idi.reserve(limit); + std::unordered_map node_tree_ids_map; + node_tree_ids_map.reserve(limit); + + InlinedVector truenode_ids, falsenode_ids; + truenode_ids.reserve(limit); + falsenode_ids.reserve(limit); max_feature_id_ = 0; + // Build node_tree_ids and node_tree_ids_map and truenode_ids and falsenode_ids for (i = 0; i < limit; ++i) { - TreeNodeElementId node_tree_id{static_cast(nodes_treeids[i]), - static_cast(nodes_nodeids[i])}; - TreeNodeElement node; - node.feature_id = static_cast(nodes_featureids[i]); - if (node.feature_id > max_feature_id_) { - max_feature_id_ = node.feature_id; - } - node.value_or_unique_weight = nodes_values_as_tensor.empty() - ? static_cast(nodes_values[i]) - : nodes_values_as_tensor[i]; - - /* hitrates is not used for inference, they are ignored. - if (nodes_hitrates_as_tensor.empty()) { - node.hitrates = static_cast(i < nodes_hitrates.size() ? nodes_hitrates[i] : -1); - } else { - node.hitrates = i < nodes_hitrates_as_tensor.size() ? nodes_hitrates_as_tensor[i] : -1; - } */ - - node.flags = static_cast(cmodes[i]); - node.truenode_inc_or_first_weight = 0; // nodes_truenodeids[i] if not a leaf - node.falsenode_inc_or_n_weights = 0; // nodes_falsenodeids[i] if not a leaf - - if (i < static_cast(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) { - node.flags |= static_cast(MissingTrack::kTrue); - } - auto p = idi.insert(std::pair(node_tree_id, i)); + TreeNodeElementId node_tree_id{static_cast(nodes_treeids[i]), static_cast(nodes_nodeids[i])}; + auto p = node_tree_ids_map.insert(std::pair(node_tree_id, i)); if (!p.second) { ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there."); } - nodes_.emplace_back(node); node_tree_ids.emplace_back(node_tree_id); } - InlinedVector truenode_ids, falsenode_ids; - truenode_ids.reserve(limit); - falsenode_ids.reserve(limit); TreeNodeElementId coor; - i = 0; - for (auto it = nodes_.begin(); it != nodes_.end(); ++it, ++i) { - if (!it->is_not_leaf()) { + for (i = 0; i < limit; ++i) { + if (cmodes[i] == NODE_MODE::LEAF) { truenode_ids.push_back(0); falsenode_ids.push_back(0); - continue; - } - - TreeNodeElementId& node_tree_id = node_tree_ids[i]; - coor.tree_id = node_tree_id.tree_id; - coor.node_id = static_cast(nodes_truenodeids[i]); - ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); + } else { + TreeNodeElementId& node_tree_id = node_tree_ids[i]; + coor.tree_id = node_tree_id.tree_id; + coor.node_id = static_cast(nodes_truenodeids[i]); + ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); + + auto found = node_tree_ids_map.find(coor); + if (found == node_tree_ids_map.end()) { + ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (truenode)."); + } + if (found->second == truenode_ids.size()) { + ORT_THROW("A node cannot point to itself: ", coor.tree_id, "-", node_tree_id.node_id, " (truenode)."); + } + truenode_ids.emplace_back(found->second); - auto found = idi.find(coor); - if (found == idi.end()) { - ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (truenode)."); - } - if (found->second == truenode_ids.size()) { - ORT_THROW("A node cannot point to itself: ", coor.tree_id, "-", node_tree_id.node_id, " (truenode)."); + coor.node_id = static_cast(nodes_falsenodeids[i]); + ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); + found = node_tree_ids_map.find(coor); + if (found == node_tree_ids_map.end()) { + ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (falsenode)."); + } + if (found->second == falsenode_ids.size()) { + ORT_THROW("A node cannot point to itself: ", coor.tree_id, "-", node_tree_id.node_id, " (falsenode)."); + } + falsenode_ids.emplace_back(found->second); + // We could also check that truenode_ids[truenode_ids.size() - 1] != falsenode_ids[falsenode_ids.size() - 1]). + // It is valid but no training algorithm would produce a tree where left and right nodes are the same. } - truenode_ids.emplace_back(found->second); + } - coor.node_id = static_cast(nodes_falsenodeids[i]); - ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); - found = idi.find(coor); - if (found == idi.end()) { - ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (falsenode)."); - } - if (found->second == falsenode_ids.size()) { - ORT_THROW("A node cannot point to itself: ", coor.tree_id, "-", node_tree_id.node_id, " (falsenode)."); + // Let's construct nodes_ such that the false branch is always the next element in nodes_. + // updated_mapping will translates the old position of each node to the new node position in nodes_. + std::vector updated_mapping(nodes_treeids.size(), 0); + int64_t previous_tree_id = -1; + for (i = 0; i < n_nodes_; ++i) { + if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) { + // New tree. + int64_t tree_id = node_tree_ids[i].tree_id; + size_t root_position = + AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values, + nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + roots_.push_back(&nodes_[root_position]); + previous_tree_id = tree_id; } - falsenode_ids.emplace_back(found->second); - // We could also check that truenode_ids[truenode_ids.size() - 1] != falsenode_ids[falsenode_ids.size() - 1]). - // It is valid but no training algorithm would produce a tree where left and right nodes are the same. } - // sort targets + n_trees_ = roots_.size(); + if (((int64_t)nodes_.size()) != n_nodes_) { + ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ")."); + } + + // Sort targets InlinedVector> indices; indices.reserve(target_class_nodeids.size()); for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) { - indices.emplace_back(std::pair( - TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, - i)); + indices.emplace_back( + std::pair(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i)); } + std::sort(indices.begin(), indices.end()); - // Initialize the leaves. TreeNodeElementId ind; SparseValue w; size_t indi; for (indi = 0, limit = target_class_nodeids.size(); indi < limit; ++indi) { ind = indices[indi].first; i = indices[indi].second; - auto found = idi.find(ind); - if (found == idi.end()) { + auto found = node_tree_ids_map.find(ind); + if (found == node_tree_ids_map.end()) { ORT_THROW("Unable to find node ", ind.tree_id, "-", ind.node_id, " (weights)."); } - TreeNodeElement& leaf = nodes_[found->second]; + TreeNodeElement& leaf = nodes_[updated_mapping[found->second]]; if (leaf.is_not_leaf()) { // An exception should be raised in that case. But this case may happen in // models converted with an old version of onnxmltools. These weights are ignored. // ORT_THROW("Node ", ind.tree_id, "-", ind.node_id, " is not a leaf."); continue; } - w.i = target_class_ids[i]; - w.value = target_class_weights_as_tensor.empty() - ? static_cast(target_class_weights[i]) - : target_class_weights_as_tensor[i]; - if (leaf.falsenode_inc_or_n_weights == 0) { - leaf.truenode_inc_or_first_weight = static_cast(weights_.size()); + w.value = target_class_weights_as_tensor.empty() ? static_cast(target_class_weights[i]) + : target_class_weights_as_tensor[i]; + if (leaf.truenode_or_weight.weight_data.n_weights == 0) { + leaf.truenode_or_weight.weight_data.weight = static_cast(weights_.size()); leaf.value_or_unique_weight = w.value; } - ++leaf.falsenode_inc_or_n_weights; + ++leaf.truenode_or_weight.weight_data.n_weights; weights_.push_back(w); } - // Initialize all the nodes but the leaves. - int64_t previous = -1; - for (i = 0, limit = static_cast(n_nodes_); i < limit; ++i) { - if ((previous == -1) || (previous != node_tree_ids[i].tree_id)) - roots_.push_back(&(nodes_[idi[node_tree_ids[i]]])); - previous = node_tree_ids[i].tree_id; - if (!nodes_[i].is_not_leaf()) { - if (nodes_[i].falsenode_inc_or_n_weights == 0) { - ORT_THROW("Target is missing for leaf ", ind.tree_id, "-", ind.node_id, "."); - } - continue; - } - ORT_ENFORCE(truenode_ids[i] != i); // That would mean the left node is itself, leading to an infinite loop. - nodes_[i].truenode_inc_or_first_weight = static_cast(truenode_ids[i] - i); - ORT_ENFORCE(falsenode_ids[i] != i); // That would mean the right node is itself, leading to an infinite loop. - nodes_[i].falsenode_inc_or_n_weights = static_cast(falsenode_ids[i] - i); - } - - n_trees_ = roots_.size(); has_missing_tracks_ = false; - for (auto itm = nodes_missing_value_tracks_true.begin(); - itm != nodes_missing_value_tracks_true.end(); ++itm) { + for (auto itm = nodes_missing_value_tracks_true.begin(); itm != nodes_missing_value_tracks_true.end(); ++itm) { if (*itm) { has_missing_tracks_ = true; break; } } + return Status::OK(); } +template +size_t TreeEnsembleCommon::AddNodes( + const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids, const std::vector& nodes_featureids, + const std::vector& nodes_values_as_tensor, const std::vector& node_values, + const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, int64_t tree_id, + const InlinedVector& node_tree_ids) { + // Validate this index maps to the same tree_id as the one we should be building. + if (node_tree_ids[i].tree_id != tree_id) { + ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i); + } + + if (updated_mapping[i] != 0) { + // In theory we should not accept any cycles, however in practice LGBM conversion implements set membership via a + // series of "Equals" nodes, with the true branches directed at the same child node (a cycle). + // We may instead seek to formalize set membership in the future. + return updated_mapping[i]; + } + + size_t node_pos = nodes_.size(); + updated_mapping[i] = node_pos; + + TreeNodeElement node; + node.flags = static_cast(cmodes[i]); + node.feature_id = static_cast(nodes_featureids[i]); + if (node.feature_id > max_feature_id_) { + max_feature_id_ = node.feature_id; + } + node.value_or_unique_weight = + nodes_values_as_tensor.empty() ? static_cast(node_values[i]) : nodes_values_as_tensor[i]; + if (i < static_cast(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) { + node.flags |= static_cast(MissingTrack::kTrue); + } + nodes_.push_back(std::move(node)); + if (nodes_[node_pos].is_not_leaf()) { + size_t false_branch = + AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, + node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + if (false_branch != node_pos + 1) { + ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ", + static_cast(nodes_[node_pos].flags)); + } + size_t true_branch = + AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, + node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + // We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_. + // nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch]; + nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch]; + } else { + nodes_[node_pos].truenode_or_weight.weight_data.weight = 0; + nodes_[node_pos].truenode_or_weight.weight_data.n_weights = 0; + } + return node_pos; +} + template Status TreeEnsembleCommon::compute(OpKernelContext* ctx, const Tensor* X, @@ -637,22 +669,19 @@ void TreeEnsembleCommon::ComputeAgg(concur } } // namespace detail -#define TREE_FIND_VALUE(CMP) \ - if (has_missing_tracks_) { \ - while (root->is_not_leaf()) { \ - val = x_data[root->feature_id]; \ - root += (val CMP root->value_or_unique_weight || \ - (root->is_missing_track_true() && _isnan_(val))) \ - ? root->truenode_inc_or_first_weight \ - : root->falsenode_inc_or_n_weights; \ - } \ - } else { \ - while (root->is_not_leaf()) { \ - val = x_data[root->feature_id]; \ - root += val CMP root->value_or_unique_weight \ - ? root->truenode_inc_or_first_weight \ - : root->falsenode_inc_or_n_weights; \ - } \ +#define TREE_FIND_VALUE(CMP) \ + if (has_missing_tracks_) { \ + while (root->is_not_leaf()) { \ + val = x_data[root->feature_id]; \ + root = (val CMP root->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val))) \ + ? root->truenode_or_weight.ptr \ + : root + 1; \ + } \ + } else { \ + while (root->is_not_leaf()) { \ + val = x_data[root->feature_id]; \ + root = val CMP root->value_or_unique_weight ? root->truenode_or_weight.ptr : root + 1; \ + } \ } inline bool _isnan_(float x) { return std::isnan(x); } @@ -671,15 +700,14 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( if (has_missing_tracks_) { while (root->is_not_leaf()) { val = x_data[root->feature_id]; - root += (val <= root->value_or_unique_weight || - (root->is_missing_track_true() && _isnan_(val))) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = (val <= root->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val))) + ? root->truenode_or_weight.ptr + : root + 1; } } else { while (root->is_not_leaf()) { val = x_data[root->feature_id]; - root += val <= root->value_or_unique_weight ? root->truenode_inc_or_first_weight : root->falsenode_inc_or_n_weights; + root = val <= root->value_or_unique_weight ? root->truenode_or_weight.ptr : root + 1; } } break; @@ -703,42 +731,36 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( } } else { // Different rules to compare to node thresholds. ThresholdType threshold; - while (root->is_not_leaf()) { + while (1) { val = x_data[root->feature_id]; threshold = root->value_or_unique_weight; switch (root->mode()) { case NODE_MODE::BRANCH_LEQ: - root += val <= threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val <= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_LT: - root += val < threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val < threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_GTE: - root += val >= threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val >= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_GT: - root += val > threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val > threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_EQ: - root += val == threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val == threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_NEQ: - root += val != threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::LEAF: - break; + return root; } } } diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.h b/onnxruntime/core/providers/cpu/nn/batch_norm.h index a5c68eebb2..be9bc3368e 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm.h @@ -29,6 +29,10 @@ namespace onnxruntime { +#if !defined(ORT_MINIMAL_BUILD) || defined(ENABLE_TRAINING_OPS) +#define BATCHNORM_INCLUDE_TRAINING_SUPPORT +#endif + template class BatchNorm : public OpKernel { public: @@ -47,7 +51,7 @@ class BatchNorm : public OpKernel { } if (is_train_) { -#ifdef ENABLE_TRAINING_OPS +#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT) momentum_ = op_kernel_info.GetAttrOrDefault("momentum", 0.9f); ORT_ENFORCE(is_spatial_, "Training mode only supports spatial BN"); #else @@ -84,7 +88,7 @@ class BatchNorm : public OpKernel { // calculate sample_size (including all channels) size_t sample_size_incl_all_channels = sample_size * C; -#ifdef ENABLE_TRAINING_OPS +#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT) AllocatorPtr alloc; ORT_RETURN_IF_ERROR(p_op_kernel_context->GetTempSpaceAllocator(&alloc)); @@ -111,7 +115,7 @@ class BatchNorm : public OpKernel { ConstEigenVectorArrayMap scale_arr(scale->Data(), is_spatial_ ? C : sample_size_incl_all_channels); ConstEigenVectorArrayMap bias_arr(B->Data(), is_spatial_ ? C : sample_size_incl_all_channels); -#ifdef ENABLE_TRAINING_OPS +#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT) // Note that we only support spatial BN for training if (is_train_) { EigenVectorArrayMap saved_mean_arr(saved_mean->MutableData(), C); @@ -162,7 +166,7 @@ class BatchNorm : public OpKernel { ConstEigenVectorArrayMap var_arr(var->Data(), is_spatial_ ? C : sample_size_incl_all_channels); inv_std = (var_arr + epsilon_).sqrt().inverse(); } else { -#ifdef ENABLE_TRAINING_OPS +#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT) EigenVectorArrayMap saved_inv_std_arr(saved_inv_std->MutableData(), C); saved_inv_std_arr = (saved_inv_std_arr + epsilon_).inverse().sqrt(); inv_std = saved_inv_std_arr; @@ -171,7 +175,7 @@ class BatchNorm : public OpKernel { // If we're training, do batch normalization based on computation from this batch ConstEigenVectorArrayMap mean_arr( -#ifdef ENABLE_TRAINING_OPS +#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT) !is_train_ ? mean->Data() : saved_mean->Data(), #else mean->Data(), diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h index a4d67ec63f..984f4795ce 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h @@ -44,14 +44,16 @@ struct ConvTransposeAttributes : public ConvAttributes { }; Status PrepareForCompute(OpKernelContext* context, bool has_bias, Prepare& p, - bool dynamic_padding = false, const TensorShape* filter_shape = nullptr) const { + bool dynamic_padding = false, const TensorShape* filter_shape = nullptr, bool is_quant = false) const { const Tensor* X = context->Input(0); - const Tensor* F = (filter_shape != nullptr) ? nullptr : context->Input(1); + const Tensor* F = (filter_shape != nullptr) ? nullptr : context->Input(is_quant ? 3 : 1); const TensorShape& F_Shape = (filter_shape != nullptr) ? *filter_shape : F->Shape(); + if (dynamic_padding && is_quant) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "dynamic padding is not supported for quantized conv tranpose."); + } const Tensor* Pads = dynamic_padding ? context->Input(2) : nullptr; - const Tensor* B = has_bias ? (dynamic_padding ? context->Input(3) : context->Input(2)) : nullptr; + const Tensor* B = has_bias ? (dynamic_padding ? context->Input(is_quant ? 9 : 3) : context->Input(is_quant ? 8 : 2)) : nullptr; TensorShape input_shape = X->Shape().Slice(2); - const int64_t num_input_channels = X->Shape()[1]; const int64_t N = X->Shape()[0]; const int64_t num_output_channels_multiplier = F_Shape[1]; diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index e9fc8d857b..21a256eee6 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -77,7 +77,8 @@ class QLinearConv : public OpKernel { W_zero_point_value = W_zero_point_data[0]; for (int64_t i = 1; i < W_zero_point_size; i++) { ORT_ENFORCE(W_zero_point_data[i] == W_zero_point_value, - "QLinearConv : zero point of per-channel filter must be same"); + "QLinearConv : zero point of per-channel filter must be same. " + "This happens by design if the quantization is symmetric."); } } diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconvtranspose.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconvtranspose.cc new file mode 100644 index 0000000000..513f72030b --- /dev/null +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconvtranspose.cc @@ -0,0 +1,254 @@ +/** +* Copyright (c) 2014-present, Quadric, Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +/* Modifications Copyright (c) Microsoft. */ + +#include "core/mlas/inc/mlas.h" +#include "core/common/safeint.h" +#include "core/util/math.h" +#include "core/util/math_cpuonly.h" + +#include "core/common/inlined_containers_fwd.h" +#include "core/framework/transpose_helper.h" +#include "core/providers/utils.h" +#include "core/framework/tensorprotoutils.h" + +#include "core/providers/cpu/nn/conv_transpose_attributes.h" + +namespace onnxruntime { + +template +class QLinearConvTranspose : public OpKernel { + public: + explicit QLinearConvTranspose(const OpKernelInfo& info) : OpKernel(info), conv_transpose_attrs_(info) { + } + + Status Compute(OpKernelContext* context) const override; + + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, + /*out*/ bool& used_shared_buffers) override; + + protected: + Status DoConvTranspose(OpKernelContext* context) const; + + private: + static float ComputeOutputScale(OpKernelContext* context) { + const Tensor* X_scale = context->Input(InputTensors::IN_X_SCALE); + const Tensor* W_scale = context->Input(InputTensors::IN_W_SCALE); + const Tensor* Y_scale = context->Input(InputTensors::IN_Y_SCALE); + ORT_ENFORCE(IsScalarOr1ElementVector(X_scale), + "QLinearConv : input scale must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(Y_scale), + "QLinearConv : result scale must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(W_scale), + "QLinearConv : filter scale must be a scalar or 1D tensor of size 1"); + + auto X_scale_value = *(X_scale->Data()); + auto Y_scale_value = *(Y_scale->Data()); + auto W_scale_value = *(W_scale->Data()); + + return X_scale_value * W_scale_value / Y_scale_value; + } + + enum InputTensors : int { + IN_X = 0, + IN_X_SCALE = 1, + IN_X_ZERO_POINT = 2, + IN_W = 3, + IN_W_SCALE = 4, + IN_W_ZERO_POINT = 5, + IN_Y_SCALE = 6, + IN_Y_ZERO_POINT = 7, + IN_BIAS = 8 + + }; + enum OutputTensors : int { + OUT_Y = 0 + }; + + ConvTransposeAttributes conv_transpose_attrs_; + + // for pre-packing usage + TensorShape filter_shape_; + BufferUniquePtr transposed_filter_; +}; + +template +Status QLinearConvTranspose::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* /*prepacked_weights*/ +) { + is_packed = false; + return Status::OK(); +} + +template +Status QLinearConvTranspose::UseSharedPrePackedBuffers(std::vector& /*prepacked_buffers*/, + int /*input_idx*/, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + return Status::OK(); +} + +template +Status QLinearConvTranspose::Compute(OpKernelContext* context) const { + return QLinearConvTranspose::DoConvTranspose(context); +} + +template +Status QLinearConvTranspose::DoConvTranspose(OpKernelContext* context) const { + typedef int32_t ActType; + size_t num_inputs = OpKernel::Node().InputDefs().size(); + ConvTransposeAttributes::Prepare p; + bool has_bias = num_inputs == 9; + ORT_RETURN_IF_ERROR(conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, false, nullptr, true)); + + const int64_t input_image_size = p.input_shape.Size(); + + // Bail out early if one of the dimensions is zero. + if (p.Y->Shape().Size() == 0) { + return Status::OK(); + } + + // Quantization parameters, only support symmetric quant for now + float scale_value = ComputeOutputScale(context); + + const int64_t X_offset = p.num_input_channels / conv_transpose_attrs_.group * input_image_size; + const int64_t Y_offset = p.Y->Shape().Size() / p.Y->Shape()[0] / conv_transpose_attrs_.group; + const int64_t W_offset = p.F->Shape().Size() / conv_transpose_attrs_.group; + const int64_t kernel_size = TensorShape(p.kernel_shape).Size(); + const int64_t kernel_dim = p.num_output_channels / conv_transpose_attrs_.group * kernel_size; + const int64_t output_size = (p.Y->Shape().Slice(2)).Size(); + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + + const int64_t col_buffer_size = kernel_dim * p.input_shape.Size(); + auto col_data = alloc->Alloc(SafeInt(sizeof(ActType)) * col_buffer_size); + BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc)); + + // Pre-transpose the weight matrix because MatMul does not take transpose params + const int64_t trans_filt_size = p.num_input_channels / conv_transpose_attrs_.group * kernel_dim; + auto trans_filt_data = alloc->Alloc(SafeInt(sizeof(T)) * trans_filt_size); + BufferUniquePtr trans_filt(trans_filt_data, BufferDeleter(alloc)); + ActType* col_buffer_data = static_cast(col_buffer.get()); + + const T* Xdata = p.X->Data(); + const T* filter_data = p.F->Data(); + T* Ydata = p.Y->MutableData(); + TensorShape output_shape = p.Y->Shape().Slice(2); + MlasTranspose( + filter_data, + static_cast(trans_filt.get()), + p.num_input_channels / conv_transpose_attrs_.group, + kernel_dim); + + //TODO: add support for assymmetric quantization and using int8 MlassGemm. This will require offseting the + //input data to an uint8 range and passing a zero point parameter. + //For now compute the GEMM in int32, as MlassGemm requires uint8 input and MlasSymmQgemmBatch is ARM only + auto inp_i32_buffer = alloc->Alloc(SafeInt(sizeof(ActType)) * p.X->Shape().Size()); + BufferUniquePtr inp_i32(inp_i32_buffer, BufferDeleter(alloc)); + ActType* inp_i32_data = static_cast(inp_i32.get()); + for (std::int64_t i = 0; i < p.X->Shape().Size(); i++) { + inp_i32_data[i] = Xdata[i]; + } + auto wt_i32_buffer = alloc->Alloc(SafeInt(sizeof(ActType)) * p.F->Shape().Size()); + BufferUniquePtr wt_i32(wt_i32_buffer, BufferDeleter(alloc)); + ActType* wt_i32_data = static_cast(wt_i32.get()); + for (std::int64_t i = 0; i < p.F->Shape().Size(); i++) { + wt_i32_data[i] = static_cast(trans_filt.get())[i]; + } + auto out_i32_buffer = alloc->Alloc(SafeInt(sizeof(ActType)) * p.Y->Shape().Size()); + BufferUniquePtr out_i32(out_i32_buffer, BufferDeleter(std::move(alloc))); + ActType* out_i32_data = static_cast(out_i32.get()); + + for (auto image_id = 0; image_id < p.N; ++image_id) { + for (int group_id = 0; group_id < conv_transpose_attrs_.group; ++group_id) { + math::MatMul(kernel_dim, input_image_size, p.num_input_channels / conv_transpose_attrs_.group, + wt_i32_data + group_id * W_offset, + inp_i32_data + group_id * X_offset, + col_buffer_data, + nullptr); + + math::Col2im( + reinterpret_cast(col_buffer_data), + p.num_output_channels / conv_transpose_attrs_.group, + p.Y->Shape()[2], + p.Y->Shape()[3], + p.kernel_shape[0], + p.kernel_shape[1], + p.dilations[0], + p.dilations[1], + p.pads[0], + p.pads[1], + p.pads[2], + p.pads[3], + p.strides[0], + p.strides[1], + reinterpret_cast(out_i32_data + group_id * Y_offset), + &CPUMathUtil::Instance()); + } + + if (p.B != nullptr) { + auto out_i32_matrix = EigenMatrixMap(out_i32_data, output_size, p.num_output_channels); + auto Bvec = ConstEigenVectorMap(p.B->Data(), p.num_output_channels); + out_i32_matrix.rowwise() += Bvec.transpose(); + } + + MlasRequantizeOutput(out_i32_data, + p.Y->Shape()[2] * p.Y->Shape()[3], + Ydata, + p.Y->Shape()[2] * p.Y->Shape()[3], + nullptr, + &scale_value, + false, + (T)0, + 0, + 0, + p.num_output_channels, + p.Y->Shape()[2] * p.Y->Shape()[3]); + Xdata += X_offset * conv_transpose_attrs_.group; + Ydata += Y_offset * conv_transpose_attrs_.group; + } + + return Status::OK(); +} + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib { + +// Register the operator with int8 inputs and weights +ONNX_OPERATOR_KERNEL_EX( + QLinearConvTranspose, + kMSDomain, + 1, + //int8_t, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + QLinearConvTranspose); + +} // namespace contrib +#endif + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index f87788e8f4..8844b7e7a2 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -308,7 +308,7 @@ Status ScatterData( const auto& upd_shape = updates_input->Shape(); const auto num_dims = input_data_shape.NumDimensions(); - assert(num_dims > 0); + ORT_RETURN_IF_NOT(num_dims > 0, "ScatterElements op: input tensor must have at least one dimension"); // Allocate and zero out counts. The input/output is of the same rank as // indices/updates but the actual dimensions of indices/updates must be less or equal diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index a50b53315e..0d9928baa8 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -20,7 +20,7 @@ namespace cuda { // float16 arithmetic is supported after sm5.3 with intrinsics, and cuda does not provide fallback for lower versions // CUDA 12.2 does not limit the definition based on sm53 anymore and defines for all arches -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12 ) && (__CUDACC_VER_MINOR__ < 2))) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) __device__ __forceinline__ half operator+(const half& lh, const half& rh) { return half((float)lh + (float)rh); } __device__ __forceinline__ half operator-(const half& lh, const half& rh) { return half((float)lh - (float)rh); } __device__ __forceinline__ half operator*(const half& lh, const half& rh) { return half((float)lh * (float)rh); } @@ -351,6 +351,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } template __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; } +template +__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; } + +template +__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); } + +template +__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed()); } + +template <> +__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); } + template __device__ __inline__ T _Normcdf(T a); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index aa60db4d07..ad892eab3b 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1180,6 +1180,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub); @@ -2118,6 +2129,17 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index f026444328..9ede1f8d90 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -157,6 +157,7 @@ UNARY_OP_HFD(Sqrt, 13) UNARY_OP_HFD(Log, 13) UNARY_OP_HFD(Exp, 13) UNARY_OP_HFD(Erf, 13) +UNARY_OP_BWUZCSILHFD(Sign, 13) UNARY_LOGICALOP_NOT_TYPED(1, bool) UNARY_OP_HFD(Round, 11) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 3ff97a6011..775b78c43a 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -112,5 +112,12 @@ class Cos final : public UnaryElementwise { Status ComputeInternal(OpKernelContext* context) const override; }; +template +class Sign final : public UnaryElementwise { + public: + Sign(const OpKernelInfo& info) : UnaryElementwise(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index ac7cc1126a..1298d53338 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -90,6 +90,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos) SPECIALIZED_UNARY_ELEMENTWISE_IMPL(Not, bool) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Sign) // When casting, half needs to be converted via float type from most other types template @@ -119,52 +120,52 @@ struct OP_Cast { } }; -#define IMPL_CAST_IMPL(InT, OutT) \ +#define IMPL_CAST_IMPL(InT, OutT) \ void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ - UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \ + UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \ } -#define IMPL_CAST_IMPL_THROW(InT, OutT) \ +#define IMPL_CAST_IMPL_THROW(InT, OutT) \ void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ - ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ + ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ } #if !defined(DISABLE_FLOAT8_TYPES) -#define IMPL_CAST_IMPL_FROM(T) \ - IMPL_CAST_IMPL(T, half) \ - IMPL_CAST_IMPL(T, float) \ - IMPL_CAST_IMPL(T, double) \ - IMPL_CAST_IMPL(T, int8_t) \ - IMPL_CAST_IMPL(T, int16_t) \ - IMPL_CAST_IMPL(T, int32_t) \ - IMPL_CAST_IMPL(T, int64_t) \ - IMPL_CAST_IMPL(T, uint8_t) \ - IMPL_CAST_IMPL(T, uint16_t) \ - IMPL_CAST_IMPL(T, uint32_t) \ - IMPL_CAST_IMPL(T, uint64_t) \ - IMPL_CAST_IMPL(T, bool) \ - IMPL_CAST_IMPL(T, BFloat16) \ - IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \ - IMPL_CAST_IMPL_THROW(T, Float8E5M2) \ +#define IMPL_CAST_IMPL_FROM(T) \ + IMPL_CAST_IMPL(T, half) \ + IMPL_CAST_IMPL(T, float) \ + IMPL_CAST_IMPL(T, double) \ + IMPL_CAST_IMPL(T, int8_t) \ + IMPL_CAST_IMPL(T, int16_t) \ + IMPL_CAST_IMPL(T, int32_t) \ + IMPL_CAST_IMPL(T, int64_t) \ + IMPL_CAST_IMPL(T, uint8_t) \ + IMPL_CAST_IMPL(T, uint16_t) \ + IMPL_CAST_IMPL(T, uint32_t) \ + IMPL_CAST_IMPL(T, uint64_t) \ + IMPL_CAST_IMPL(T, bool) \ + IMPL_CAST_IMPL(T, BFloat16) \ + IMPL_CAST_IMPL_THROW(T, Float8E4M3FN) \ + IMPL_CAST_IMPL_THROW(T, Float8E5M2) \ IMPL_CAST_IMPL_THROW(T, Float8E4M3FNUZ) \ IMPL_CAST_IMPL_THROW(T, Float8E5M2FNUZ) #else -#define IMPL_CAST_IMPL_FROM(T) \ - IMPL_CAST_IMPL(T, half) \ - IMPL_CAST_IMPL(T, float) \ - IMPL_CAST_IMPL(T, double) \ - IMPL_CAST_IMPL(T, int8_t) \ - IMPL_CAST_IMPL(T, int16_t) \ - IMPL_CAST_IMPL(T, int32_t) \ - IMPL_CAST_IMPL(T, int64_t) \ - IMPL_CAST_IMPL(T, uint8_t) \ - IMPL_CAST_IMPL(T, uint16_t) \ - IMPL_CAST_IMPL(T, uint32_t) \ - IMPL_CAST_IMPL(T, uint64_t) \ - IMPL_CAST_IMPL(T, bool) \ +#define IMPL_CAST_IMPL_FROM(T) \ + IMPL_CAST_IMPL(T, half) \ + IMPL_CAST_IMPL(T, float) \ + IMPL_CAST_IMPL(T, double) \ + IMPL_CAST_IMPL(T, int8_t) \ + IMPL_CAST_IMPL(T, int16_t) \ + IMPL_CAST_IMPL(T, int32_t) \ + IMPL_CAST_IMPL(T, int64_t) \ + IMPL_CAST_IMPL(T, uint8_t) \ + IMPL_CAST_IMPL(T, uint16_t) \ + IMPL_CAST_IMPL(T, uint32_t) \ + IMPL_CAST_IMPL(T, uint64_t) \ + IMPL_CAST_IMPL(T, bool) \ IMPL_CAST_IMPL(T, BFloat16) #endif @@ -199,58 +200,58 @@ struct OP_CastNoSat { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 -#define OP_CAST(T, NVT) \ - template <> \ - struct OP_CastSat { \ - __device__ __inline__ T operator()(const half& v) const { \ +#define OP_CAST(T, NVT) \ + template <> \ + struct OP_CastSat { \ + __device__ __inline__ T operator()(const half& v) const { \ return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \ - } \ - }; \ - template <> \ - struct OP_CastNoSat { \ - __device__ __inline__ T operator()(const half& v) const { \ - return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \ - } \ - }; \ - template <> \ - struct OP_CastSat { \ - __device__ __inline__ T operator()(const float& v) const { \ - return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \ - } \ - }; \ - template <> \ - struct OP_CastNoSat { \ - __device__ __inline__ T operator()(const float& v) const { \ - return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \ - } \ + } \ + }; \ + template <> \ + struct OP_CastNoSat { \ + __device__ __inline__ T operator()(const half& v) const { \ + return T(static_cast(__nv_cvt_halfraw_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \ + } \ + }; \ + template <> \ + struct OP_CastSat { \ + __device__ __inline__ T operator()(const float& v) const { \ + return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_SATFINITE, NVT)), T::FromBits()); \ + } \ + }; \ + template <> \ + struct OP_CastNoSat { \ + __device__ __inline__ T operator()(const float& v) const { \ + return T(static_cast(__nv_cvt_float_to_fp8(v, __NV_NOSAT, NVT)), T::FromBits()); \ + } \ }; #else -#define OP_CAST(T, NVT) \ - template <> \ - struct OP_CastSat { \ - __device__ __inline__ T operator()(const half& v) const { \ - return T(__half2float(v), true); \ - } \ - }; \ - template <> \ - struct OP_CastNoSat { \ - __device__ __inline__ T operator()(const half& v) const { \ - return T(__half2float(v), false); \ - } \ - }; \ - template <> \ - struct OP_CastSat { \ +#define OP_CAST(T, NVT) \ + template <> \ + struct OP_CastSat { \ + __device__ __inline__ T operator()(const half& v) const { \ + return T(__half2float(v), true); \ + } \ + }; \ + template <> \ + struct OP_CastNoSat { \ + __device__ __inline__ T operator()(const half& v) const { \ + return T(__half2float(v), false); \ + } \ + }; \ + template <> \ + struct OP_CastSat { \ __device__ __inline__ T operator()(const float& v) const { \ - return T(v, true); \ - } \ - }; \ - template <> \ - struct OP_CastNoSat { \ + return T(v, true); \ + } \ + }; \ + template <> \ + struct OP_CastNoSat { \ __device__ __inline__ T operator()(const float& v) const { \ - return T(v, false); \ - } \ + return T(v, false); \ + } \ }; #endif @@ -260,14 +261,13 @@ struct OP_CastNoSat { OP_CAST(Float8E4M3FN, __NV_E4M3) OP_CAST(Float8E5M2, __NV_E5M2) - -#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \ +#define EXPLICIT_IMPL_CASTSAT(InT, OutT) \ void Explicit_Impl_CastSat(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count, bool saturate) { \ - if (saturate) { \ - UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat(), count); \ - } else { \ - UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat(), count); \ - } \ + if (saturate) { \ + UnaryElementWiseImpl(stream, input_data, output_data, OP_CastSat(), count); \ + } else { \ + UnaryElementWiseImpl(stream, input_data, output_data, OP_CastNoSat(), count); \ + } \ } EXPLICIT_IMPL_CASTSAT(float, Float8E4M3FN) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index 3d4868b54a..608a81a24c 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -31,7 +31,8 @@ namespace cuda { UNARY_OP_NAME_EXPR(Not, !a) \ UNARY_OP_NAME_EXPR(Round, _Round(a)) \ UNARY_OP_NAME_EXPR(Sin, _Sin(a)) \ - UNARY_OP_NAME_EXPR(Cos, _Cos(a)) + UNARY_OP_NAME_EXPR(Cos, _Cos(a)) \ + UNARY_OP_NAME_EXPR(Sign, _Sign(a)) #define UNARY_ELEMENTWISE_IMPL_DECLARATION(name) \ template \ diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index 4cc560a117..679b8b6b78 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -104,17 +104,17 @@ __device__ void cuWelfordMuSigma2( const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; const T* lvals = vals + i1 * n2; - const T* skip_vals = (skip != NULL) ? skip + i1 * n2 : NULL; + const T* skip_vals = (skip != nullptr) ? skip + i1 * n2 : nullptr; int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { U curr = static_cast(lvals[l + k]); - if (bias != NULL) { + if (bias != nullptr) { curr += static_cast(bias[l + k]); } - if (skip_vals != NULL) { + if (skip_vals != nullptr) { curr += static_cast(skip_vals[l + k]); } @@ -124,11 +124,11 @@ __device__ void cuWelfordMuSigma2( for (; l < n2; ++l) { U curr = static_cast(lvals[l]); - if (bias != NULL) { + if (bias != nullptr) { curr += static_cast(bias[l]); } - if (skip_vals != NULL) { + if (skip_vals != nullptr) { curr += static_cast(skip_vals[l]); } @@ -301,7 +301,7 @@ namespace { // { // extern __device__ void error(void); // error(); -// return NULL; +// return nullptr; // } // }; // https://github.com/NVIDIA/apex/issues/246 @@ -338,9 +338,7 @@ __global__ void cuApplyLayerNorm( const V* __restrict__ beta, const T* __restrict__ skip, const T* __restrict__ bias, - T* __restrict__ skip_input_bias_add_output, - const bool skip_broadcasted, - const int skip_size) { + T* __restrict__ skip_input_bias_add_output) { // Assumptions: // 1) blockDim.x == GPU_WARP_SIZE // 2) Tensors are contiguous @@ -350,38 +348,35 @@ __global__ void cuApplyLayerNorm( U* buf = shared.getPointer(); U mu, sigma2; cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, skip, bias); - const T* lvals = vals + i1 * n2; - const T* skip_vals = (skip != NULL) ? skip + i1 * n2 : NULL; - V* ovals = output_vals + i1 * n2; - T* skip_input_bias_add_ovals = (skip_input_bias_add_output != NULL) ? skip_input_bias_add_output + i1 * n2 : NULL; + const int offset = i1 * n2; + const T* lvals = vals + offset; + const T* skip_vals = (skip != nullptr) ? skip + offset : nullptr; + + V* ovals = output_vals + offset; + T* skip_input_bias_add_ovals = (skip_input_bias_add_output != nullptr) ? skip_input_bias_add_output + offset : nullptr; U c_inv_std_dev = rsqrt(sigma2 + epsilon); const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; for (int i = thrx; i < n2; i += numx) { U curr = static_cast(lvals[i]); - - - if (bias != NULL) { + if (bias != nullptr) { curr += static_cast(bias[i]); } - if (skip_vals != NULL && skip_broadcasted) { - int skip_i = i % skip_size; - curr += static_cast(skip_vals[skip_i]); //Calculates index for the second dimension of the skip tensor - }else if (skip_vals != NULL){ + if (skip_vals != nullptr) { curr += static_cast(skip_vals[i]); } - U gamma_i = (gamma != NULL) ? (U)gamma[i] : (U)1; - U beta_i = (beta != NULL) ? (U)beta[i] : (U)0; + U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1; + U beta_i = (beta != nullptr) ? (U)beta[i] : (U)0; if (simplified) { ovals[i] = static_cast(gamma_i * c_inv_std_dev * curr); } else { ovals[i] = static_cast(gamma_i * c_inv_std_dev * (curr - mu) + beta_i); } - if (skip_input_bias_add_ovals != NULL) { + if (skip_input_bias_add_ovals != nullptr) { skip_input_bias_add_ovals[i] = static_cast(curr); } } @@ -418,9 +413,7 @@ void HostApplyLayerNorm( const V* beta, const T* skip, const T* bias, - T* skip_input_bias_add_output, - const bool skip_broadcasted, - const int skip_size) { + T* skip_input_bias_add_output) { const int maxGridY = prop.maxGridSize[1]; const int warp_size = prop.warpSize; ORT_ENFORCE(warp_size == GPU_WARP_SIZE_HOST); @@ -452,17 +445,14 @@ void HostApplyLayerNorm( n1, n2, U(epsilon), gamma, beta, - skip, bias, skip_input_bias_add_output, - skip_broadcasted, - skip_size); + skip, bias, skip_input_bias_add_output); } #define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \ template void HostApplyLayerNorm(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \ U* mean, U* inv_std_dev, const T* input, int n1, int n2, \ double epsilon, const V* gamma, const V* beta, const T* skip, \ - const T* bias, T* skip_input_bias_add_output, const bool skip_broadcasted, \ - const int skip_size); + const T* bias, T* skip_input_bias_add_output); LAYERNORM_LINEAR_IMPL(float, float, float, true) LAYERNORM_LINEAR_IMPL(half, float, half, true) diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h index d0d5db8ba3..e3952eefae 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h @@ -43,9 +43,7 @@ void HostApplyLayerNorm( const V* beta, const T* skip = nullptr, const T* bias = nullptr, - T* skip_input_bias_add_output = nullptr, - const bool skip_broadcasted = false, - const int skip_size = 0); + T* skip_input_bias_add_output = nullptr); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h index 52018500b1..cdb0338157 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h @@ -3,6 +3,9 @@ #pragma once interface IMLOperatorRegistry; +interface IDMLDevice; +interface ID3D12CommandQueue; +interface ID3D12Resource; #include "core/common/status.h" #include "core/framework/data_transfer.h" @@ -28,7 +31,8 @@ namespace Dml std::unique_ptr CreateExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands = true); + bool enableMetacommands, + bool enableDynamicGraphFusion); ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr); void FlushContext(onnxruntime::IExecutionProvider* provider); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 232a022d86..074f13b309 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -7,11 +7,14 @@ #include #include #include +#include #include "core/framework/op_kernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" struct AbstractOperatorDesc; interface IMLOperatorTensor; +interface IDMLOperator; struct DML_INPUT_GRAPH_EDGE_DESC; struct DML_OUTPUT_GRAPH_EDGE_DESC; struct DML_INTERMEDIATE_GRAPH_EDGE_DESC; @@ -80,7 +83,7 @@ namespace Windows::AI::MachineLearning::Adapter // Either nodesAsOperatorDesc or nodesAsIDMLOperator can have non-zero size. struct DmlGraphNodeCreateInfo { - uint32_t nodeCount; + uint32_t nodeCount = 0; std::vector> nodesAsOperatorDesc; std::vector> nodesAsIDMLOperator; std::vector inputEdges; @@ -92,6 +95,8 @@ namespace Windows::AI::MachineLearning::Adapter const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, + const EdgeShapes* inputShapesOverrides, + /*out*/ EdgeShapes* outputShapes, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo )>; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index ede3e7f2c2..eb068087de 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -491,6 +491,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, + const EdgeShapes* inputShapesOverrides, + /*out*/ EdgeShapes* outputShapes, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo ) { @@ -498,15 +500,15 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( onnxruntime::OpNodeProtoHelper protoHelper(&nodeContext); // Use the same list of required constant inputs for the shape inferrer and the kernel. - EdgeShapes outputShapes; - InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes); + InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes); // Create the kernel while allowing input shape and output shape queries according to options ComPtr kernelInfoWrapper = wil::MakeOrThrow( &protoHelper, executionHandle, true, - &outputShapes, + inputShapesOverrides, + outputShapes, &defaultAttributesCapture, graphNodeCreateInfo, constantCpuInputCapture, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index 5dbea41901..c24257071e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -212,15 +212,6 @@ namespace Dml ORT_THROW_HR(E_INVALIDARG); } const auto* allocInfo = static_cast(opaqueHandle); - - auto owner = allocInfo->GetOwner(); - //The owner can be null if the resource was wrapped via CreateGPUAllocationFromD3DResource - if (owner != nullptr && owner != this) - { - // This allocation doesn't belong to this allocator! - ORT_THROW_HR(E_INVALIDARG); - } - return allocInfo; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h index 4c24cb174f..196fba5d76 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h @@ -83,16 +83,16 @@ namespace Dml std::vector m_pool; size_t m_currentAllocationId = 0; uint64_t m_currentResourceId = 0; - - // Unless specifically requested, allocation sizes are not rounded to enable pooling - // until SetDefaultRoundingMode is called. This should be done at completion of session + + // Unless specifically requested, allocation sizes are not rounded to enable pooling + // until SetDefaultRoundingMode is called. This should be done at completion of session // initialization. AllocatorRoundingMode m_defaultRoundingMode = AllocatorRoundingMode::Disabled; std::shared_ptr m_context; std::unique_ptr m_subAllocator; - #if _DEBUG + #ifndef NDEBUG // Useful for debugging; keeps track of all allocations that haven't been freed yet std::map m_outstandingAllocationsById; #endif diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h new file mode 100644 index 0000000000..5ff7049325 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace Windows::AI::MachineLearning::Adapter +{ + // edges and unused edges have an empty array of dimensions. + class EdgeShapes + { + public: + EdgeShapes() = default; + + EdgeShapes(size_t count) : m_shapes(count) {} + + const std::vector& GetShape(size_t edgeIndex) const + { + return m_shapes[edgeIndex]; + } + + std::vector& GetMutableShape(size_t edgeIndex) + { + return m_shapes[edgeIndex]; + } + + size_t EdgeCount() const { return m_shapes.size(); } + + void Reset(size_t edge_count) + { + m_shapes.clear(); + m_shapes.resize(edge_count); + } + + bool operator!=(const EdgeShapes& other) const noexcept + { + return (m_shapes != other.m_shapes); + } + + private: + std::vector> m_shapes; + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 51b93efb3a..4f7ec18814 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -1,7 +1,7 @@ #pragma once #include "DmlGraphFusionHelper.h" - +#include "DmlRuntimeFusedGraphKernel.h" namespace Dml { @@ -103,6 +103,36 @@ namespace DmlGraphFusionHelper ORT_THROW_IF_FAILED(resourceUnk->QueryInterface(resource)); } + std::tuple, std::vector, std::byte*, size_t> UnpackInitializer( + const onnxruntime::Graph& graph, + const ONNX_NAMESPACE::TensorProto* initializer) + { + std::unique_ptr unpackedTensor; + std::vector unpackedExternalTensor; + std::byte* tensorPtr = nullptr; + size_t tensorByteSize = 0; + + // The tensor may be stored as raw data or in typed fields. + if (initializer->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) + { + THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*initializer, graph.ModelPath(), unpackedExternalTensor)); + tensorPtr = reinterpret_cast(unpackedExternalTensor.data()); + tensorByteSize = unpackedExternalTensor.size(); + } + else if (initializer->has_raw_data()) + { + tensorPtr = (std::byte*)(initializer->raw_data().c_str()); + tensorByteSize = initializer->raw_data().size(); + } + else + { + std::tie(unpackedTensor, tensorByteSize) = Windows::AI::MachineLearning::Adapter::UnpackTensor(*initializer, graph.ModelPath()); + tensorPtr = unpackedTensor.get(); + } + + return std::make_tuple(std::move(unpackedTensor), std::move(unpackedExternalTensor), tensorPtr, tensorByteSize); + } + void ProcessInputData( const ExecutionProviderImpl* providerImpl, const std::vector& isInputsUploadedByDmlEP, @@ -161,32 +191,11 @@ namespace DmlGraphFusionHelper auto iter = initializerNameToInitializerMap.find(subGraphInputArgNames[i]); if (iter != initializerNameToInitializerMap.end()) { - std::byte* tensorPtr = nullptr; - size_t tensorByteSize = 0; - std::vector unpackedExternalTensor; - - std::unique_ptr unpackedTensor; - - //auto& initializer = iter->second; auto* initializer = iter->second.first; + auto [unpackedTensor, unpackedExternalTensor, tensorPtr, tensorByteSize] = UnpackInitializer(graph, initializer); - // The tensor may be stored as raw data or in typed fields. - if (initializer->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) - { - THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*initializer, graph.ModelPath(), unpackedExternalTensor)); - tensorPtr = reinterpret_cast(unpackedExternalTensor.data()); - tensorByteSize = unpackedExternalTensor.size(); - } - else if (initializer->has_raw_data()) + if (initializer->data_location() != onnx::TensorProto_DataLocation_EXTERNAL && !initializer->has_raw_data()) { - tensorPtr = (std::byte*)(initializer->raw_data().c_str()); - tensorByteSize = initializer->raw_data().size(); - } - else - { - std::tie(unpackedTensor, tensorByteSize) = Windows::AI::MachineLearning::Adapter::UnpackTensor(*initializer, graph.ModelPath()); - tensorPtr = unpackedTensor.get(); - // Free the initializer if this is the last usage of it. if (initializerToLastInputIndexMap[initializer] == i) { @@ -501,5 +510,173 @@ namespace DmlGraphFusionHelper graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode); } + + void RegisterDynamicKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable) + { + struct NodeInfo + { + std::string name; + std::string opType; + std::string description; + std::string domain; + onnxruntime::NodeAttributes attributes; + std::vector inputDefPointers; + std::vector outputDefPointers; + }; + + auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap( + graph, + *indexedSubGraph, + std::move(graphNodePropertyMap)); + + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs; + + std::vector nodesInfo; + nodesInfo.reserve(indexedSubGraph->nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + std::vector nodeAttributes; + nodeAttributes.reserve(indexedSubGraph->nodes.size()); + + std::vector> intermediateNodeArgs; + + for (size_t sortedNodeIndex : indexedSubGraph->nodes) + { + auto node = graph.GetNode(sortedNodeIndex); + + nodeAttributes.push_back(node->GetAttributes()); + + NodeInfo nodeInfo{}; + nodeInfo.name = node->Name(); + nodeInfo.opType = node->OpType(); + nodeInfo.description = node->Description(); + nodeInfo.domain = node->Domain(); + nodeInfo.attributes = node->GetAttributes(); + nodeInfo.inputDefPointers.reserve(node->InputDefs().size()); + nodeInfo.outputDefPointers.reserve(node->OutputDefs().size()); + + for (const onnxruntime::NodeArg* inputDef : node->InputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(inputDef->Name(), inputDef->TypeAsProto())); + nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + for (const onnxruntime::NodeArg* outputDef : node->OutputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(outputDef->Name(), outputDef->TypeAsProto())); + nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + nodesInfo.push_back(std::move(nodeInfo)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + + // We need to keep the initializers alive since they will be freed once the nodes are removed from the graph + std::vector ownedInitializers; + ownedInitializers.reserve(isInitializerTransferable.size()); + + for (auto& kvp : isInitializerTransferable) + { + auto [unpackedTensor, unpackedExternalTensor, tensorPtr, tensorByteSize] = UnpackInitializer(graph, kvp.second.first); + + ONNX_NAMESPACE::TensorProto tensorProto; + tensorProto.set_data_type(kvp.second.first->data_type()); + tensorProto.set_raw_data(tensorPtr, tensorByteSize); + tensorProto.set_name(kvp.second.first->name()); + + for (int i = 0; i < kvp.second.first->dims_size(); ++i) + { + tensorProto.add_dims(kvp.second.first->dims(i)); + } + ownedInitializers.push_back(std::move(tensorProto)); + kvp.second.first = &ownedInitializers.back(); + } + + // lamda captures for the kernel registration + auto fused_kernel_func = [ + indexedSubGraph, + &modelPath, + nodesInfo = std::move(nodesInfo), + intermediateNodeArgs = std::move(intermediateNodeArgs), + subgraphInputs = std::move(subgraphInputs), + subgraphOutputs = std::move(subgraphOutputs), + partitionNodePropsMap = std::move(partitionNodePropsMap), + ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status + { + std::vector> subgraphNodes; + subgraphNodes.reserve(nodesInfo.size()); + + for (const NodeInfo& nodeInfo : nodesInfo) + { + subgraphNodes.emplace_back(std::make_shared( + nodeInfo.name, + nodeInfo.opType, + nodeInfo.description, + nodeInfo.inputDefPointers, + nodeInfo.outputDefPointers, + &nodeInfo.attributes, + nodeInfo.domain)); + } + + out.reset(CreateRuntimeFusedGraphKernel( + info, + indexedSubGraph, + modelPath, + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers))); + return Status::OK(); + }; + + // build the kernel definition on the fly, and register it to the fused_kernel_regisitry. + onnxruntime::KernelDefBuilder builder; + builder.SetName(indexedSubGraph->GetMetaDef()->name) + .SetDomain(indexedSubGraph->GetMetaDef()->domain) + .SinceVersion(indexedSubGraph->GetMetaDef()->since_version) + .Provider(onnxruntime::kDmlExecutionProvider); + + // Force the CPU inputs to be allocated on the CPU + for (int i = 0; i < subGraphInputArgNames.size(); ++i) + { + if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end()) + { + builder.InputMemoryType(OrtMemTypeCPUInput, i); + } + } + + ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func)); + + auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name); + fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); + + graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode); + } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index 030cffc2a8..f8f6162aaa 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -80,5 +80,14 @@ namespace DmlGraphFusionHelper std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, Microsoft::WRL::ComPtr compiledExecutionPlanOperator); + + void RegisterDynamicKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index a9d19a022d..679738b639 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -15,6 +15,18 @@ namespace Dml { + namespace + { + struct CompiledPartitionInfo + { + Microsoft::WRL::ComPtr compiledOperator; + onnxruntime::IndexedSubGraph indexedSubGraph; + std::vector isInputsUploadedByDmlEP; + GraphDescBuilder::GraphDesc graphDesc; + std::unordered_map> isInitializerTransferable; + }; + } + DmlGraphFusionTransformer::DmlGraphFusionTransformer( const std::string& name, const onnxruntime::IExecutionProvider* provider @@ -24,20 +36,21 @@ namespace Dml { } - struct CompiledPartitionInfo - { - Microsoft::WRL::ComPtr compiledOperator; - onnxruntime::IndexedSubGraph indexedSubGraph; - std::vector isInputsUploadedByDmlEP; - GraphDescBuilder::GraphDesc graphDesc; - std::unordered_map> isInitializerTransferable; - }; - onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, int graph_level, const onnxruntime::logging::Logger& logger) const + { + return ApplyImplHelper(graph, modified, graph_level, logger, {}); + } + + onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const { onnxruntime::ProviderType provider_type = onnxruntime::kDmlExecutionProvider; const gsl::not_null registry = m_providerImpl->GetKernelRegistry().get(); @@ -49,10 +62,35 @@ namespace Dml std::vector> compiledPartitionInfos; std::vector additionalSplittingNodes; + onnxruntime::GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) + { + auto* node = graph.GetNode(node_index); + if (!node) + { + continue; // node was removed + } + + std::unordered_map subgraphImplicitInputDefs; + for (const onnxruntime::NodeArg* inputDef : node->ImplicitInputDefs()) + { + subgraphImplicitInputDefs[inputDef->Name()] = inputDef; + } + + for (auto& entry : node->GetAttributeNameToMutableSubgraphMap()) + { + auto& subgraph = *entry.second; + ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graph_level + 1, logger, subgraphImplicitInputDefs)); + } + } + do { // Initializers needed by any graph partition std::unordered_set requiredInitializerMap; + std::unordered_set dynamicCpuInputMap; std::unordered_map graphNodePropertyMap; onnxruntime::GraphViewer graphViewer(graph); std::vector> partitions = BuildPartitions( @@ -62,7 +100,10 @@ namespace Dml m_providerImpl->GetSupportedDeviceDataTypeMask(), graphNodePropertyMap, requiredInitializerMap, - additionalSplittingNodes); + dynamicCpuInputMap, + additionalSplittingNodes, + implicitInputDefs, + false); // Reset the splitting nodes for the current iteration additionalSplittingNodes.clear(); @@ -155,17 +196,48 @@ namespace Dml std::move(graphNodePropertyMap)); // Convert partitionONNXGraph into DML EP GraphDesc + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs; + + std::vector subgraphNodes; + subgraphNodes.reserve(indexedSubGraph.nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + for (size_t sortedNodeIndex : indexedSubGraph.nodes) + { + subgraphNodes.push_back(graph.GetNode(sortedNodeIndex)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + ComPtr device; ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf())); GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), isInitializerTransferable, - graph, - indexedSubGraph, partitionNodePropsMap, device.Get(), - m_providerImpl); + m_providerImpl, + modelPath, + subgraphNodes, + subgraphInputs, + subgraphOutputs); // Compile the operator auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h index b546f29f59..19dab0c899 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h @@ -2,32 +2,41 @@ // Licensed under the MIT License. #pragma once - +#include +#include #include "core/optimizer/graph_transformer.h" #include "core/framework/execution_providers.h" namespace Dml { - class ExecutionProviderImpl; +class ExecutionProviderImpl; + +class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer +{ +public: + DmlGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ); + +public: + static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlFusedNode_"; + static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlFusedNodeDomain"; - class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer - { - public: - DmlGraphFusionTransformer( - const std::string& name, - const onnxruntime::IExecutionProvider* provider - ); +private: + onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger) const final; - public: - inline const static char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlFusedNode_"; - inline const static char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlFusedNodeDomain"; + onnxruntime::common::Status ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graph_level, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const; - private: - onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, - bool& modified, - int graph_level, - const onnxruntime::logging::Logger& logger) const final; - private: - const ExecutionProviderImpl* m_providerImpl = nullptr; - }; +private: + const ExecutionProviderImpl* m_providerImpl = nullptr; +}; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp new file mode 100644 index 0000000000..1db22ac92e --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -0,0 +1,369 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h" + +using namespace Windows::AI::MachineLearning::Adapter; + +namespace Dml +{ + class DmlRuntimeFusedGraphKernel : public onnxruntime::OpKernel + { + public: + DmlRuntimeFusedGraphKernel() = delete; + + DmlRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& kernelInfo, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers) + : OpKernel(kernelInfo), + m_indexedSubGraph(std::move(indexedSubGraph)), + m_modelPath(modelPath), + m_subgraphNodes(std::move(subgraphNodes)), + m_subgraphInputs(std::move(subgraphInputs)), + m_subgraphOutputs(std::move(subgraphOutputs)), + m_intermediateNodeArgs(std::move(intermediateNodeArgs)), + m_partitionNodePropsMap(std::move(partitionNodePropsMap)), + m_ownedInitializers(std::move(ownedInitializers)) + { + for (const auto& initializer : m_ownedInitializers) + { + m_isInitializerTransferable[initializer.name()] = std::make_pair(&initializer, false); + } + + // Get the execution provider interfaces + auto executionHandle = kernelInfo.GetExecutionProvider()->GetExecutionHandle(); + if (executionHandle) + { + // We assume the execution object inherits IUnknown as its first base + ComPtr providerExecutionObject = const_cast(static_cast(executionHandle)); + + // Get the WinML-specific execution provider interface from the execution object. + ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_provider)); + ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider)); + } + + m_subgraphNodePointers.reserve(m_subgraphNodes.size()); + + for (auto& subgraphNode : m_subgraphNodes) + { + m_subgraphNodePointers.push_back(subgraphNode.get()); + } + } + + void TranslateAndCompileGraph( + const onnxruntime::OpKernelInfo& kernelInfo, + std::vector>& initializeResourceRefs, + std::vector initInputBindings) const + { + // Allocate a persistent resource and initialize the operator + UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize; + if (persistentResourceSize > 0) + { + ORT_THROW_IF_FAILED(m_provider->AllocatePooledResource( + static_cast(persistentResourceSize), + AllocatorRoundingMode::Disabled, + m_persistentResource.ReleaseAndGetAddressOf(), + m_persistentResourceAllocatorUnk.ReleaseAndGetAddressOf())); + + m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; + } + + ORT_THROW_IF_FAILED(m_provider->InitializeOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + gsl::make_span(initInputBindings))); + + std::for_each( + initializeResourceRefs.begin(), + initializeResourceRefs.end(), + [&](ComPtr& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); } + ); + } + + onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override + { + ORT_THROW_HR_IF(E_UNEXPECTED, m_subgraphInputs.size() != kernelContext->InputCount()); + + bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr; + + for (int inputIndex = 0; inputIndex < kernelContext->InputCount(); ++inputIndex) + { + const auto& input = kernelContext->RequiredInput(inputIndex); + const std::string& inputName = m_subgraphInputs[inputIndex]->Name(); + auto shapeIter = m_inferredInputShapes.find(inputName); + + if (shapeIter == m_inferredInputShapes.end()) + { + m_inferredInputShapes[inputName] = input.Shape(); + recompileNeeded = true; + } + else if (shapeIter->second != input.Shape()) + { + shapeIter->second = input.Shape(); + recompileNeeded = true; + } + + // If we have CPU inputs that are not initializers (i.e. they were computed at runtime), add them to the initializer list + if (input.Location().device.Type() == OrtDevice::CPU) + { + auto inputProto = onnxruntime::utils::TensorToTensorProto(input, inputName); + + // We can only avoid recompiling the graph when all CPU inputs are identical + auto initializerIter = m_isInitializerTransferable.find(inputName); + + if (initializerIter != m_isInitializerTransferable.end()) + { + if (initializerIter->second.first->raw_data().length() == inputProto.raw_data().length()) + { + for (int i = 0; i < inputProto.raw_data().length(); ++i) + { + if (initializerIter->second.first->raw_data()[i] != inputProto.raw_data()[i]) + { + recompileNeeded = true; + break; + } + } + } + else + { + recompileNeeded = true; + } + } + else + { + recompileNeeded = true; + } + + m_ownedCpuInputs.push_back(std::make_unique(std::move(inputProto))); + m_isInitializerTransferable[inputName] = std::make_pair(m_ownedCpuInputs.back().get(), false); + } + } + + if (recompileNeeded) + { + // Go through all the node args and replace their shapes with the real ones + for (auto& nodeArg : m_intermediateNodeArgs) + { + auto iter = m_inferredInputShapes.find(nodeArg->Name()); + if (iter != m_inferredInputShapes.end()) + { + auto tensorShape = *nodeArg->Shape(); + ORT_THROW_HR_IF(E_UNEXPECTED, tensorShape.dim_size() != iter->second.NumDimensions()); + + for (int i = 0; i < tensorShape.dim_size(); ++i) + { + tensorShape.mutable_dim(i)->set_dim_value(iter->second.GetDims()[i]); + } + + nodeArg->SetShape(tensorShape); + } + } + + // Populate input bindings for operator initialization + const uint32_t fusedNodeInputCount = gsl::narrow_cast(m_indexedSubGraph->GetMetaDef()->inputs.size()); + std::vector> initializeResourceRefs; // For lifetime control + std::vector initInputBindings(fusedNodeInputCount); + std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); + auto providerImpl = static_cast(Info().GetExecutionProvider())->GetImpl(); + + // Convert partitionONNXGraph into DML EP GraphDesc + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( + isInputsUploadedByDmlEP.data(), + isInputsUploadedByDmlEP.size(), + m_isInitializerTransferable, + m_partitionNodePropsMap, + device.Get(), + providerImpl, + m_modelPath, + m_subgraphNodePointers, + m_subgraphInputs, + m_subgraphOutputs); + + m_outputShapes = graphDesc.outputShapes; + + // Walk through each graph edge and mark used inputs + m_inputsUsed.resize(fusedNodeInputCount, false); + for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) + { + m_inputsUsed[edge.GraphInputIndex] = true; + } + + // Compile the operator + m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( + graphDesc, + *m_indexedSubGraph, + providerImpl); + + // Queue references to objects which must be kept alive until resulting GPU work completes + m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); + + TranslateAndCompileGraph( + Info(), + initializeResourceRefs, + initInputBindings); + } + + // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator + OpKernelContextWrapper contextWrapper( + kernelContext, + Info().GetExecutionProvider(), + true, + nullptr); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + // Get input resources for execution, excluding those which were specified as owned by DML and provided + // at initialization instead. + std::vector> inputTensors(kernelContext->InputCount()); + std::vector inputPtrs(kernelContext->InputCount()); + + for (int i = 0; i < kernelContext->InputCount(); ++i) + { + if (!m_inputsUsed[i]) + { + continue; + } + + ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf())); + inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get()); + } + + auto outputTensors = contextWrapper.GetOutputTensors(m_outputShapes); + ExecuteOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + inputPtrs, + outputTensors); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + return onnxruntime::Status::OK(); + } + + void ExecuteOperator( + IDMLCompiledOperator* op, + _In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding, + gsl::span inputTensors, + gsl::span outputTensors) const + { + auto FillBindingsFromTensors = [this](auto& bufferBindings, auto& bindingDescs, gsl::span& tensors) + { + for (IMLOperatorTensor* tensor : tensors) + { + if (tensor) + { + assert(tensor->IsDataInterface()); + ID3D12Resource* resource = m_provider->DecodeResource(MLOperatorTensor(tensor).GetDataInterface().Get()); + D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); + bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); + } + else + { + bufferBindings.push_back({ nullptr, 0, 0 }); + bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + } + } + }; + + auto FillBindingsFromBuffers = [](auto& bufferBindings, auto& bindingDescs, gsl::span& resources) + { + for (ID3D12Resource* resource : resources) + { + if (resource) + { + D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); + bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); + } + else + { + bufferBindings.push_back({ nullptr, 0, 0 }); + bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + } + } + }; + + std::vector inputBufferBindings; + inputBufferBindings.reserve(inputTensors.size()); + std::vector inputBindings; + inputBindings.reserve(inputTensors.size()); + FillBindingsFromBuffers(inputBufferBindings, inputBindings, inputTensors); + + std::vector outputBufferBindings; + outputBufferBindings.reserve(outputTensors.size()); + std::vector outputBindings; + outputBindings.reserve(outputTensors.size()); + FillBindingsFromTensors(outputBufferBindings, outputBindings, outputTensors); + + ORT_THROW_IF_FAILED(m_provider->ExecuteOperator( + op, + persistentResourceBinding, + inputBindings, + outputBindings)); + } + + private: + ComPtr m_winmlProvider; + ComPtr m_provider; + + mutable std::optional m_persistentResourceBinding; + std::shared_ptr m_indexedSubGraph; + const onnxruntime::Path& m_modelPath; + + std::vector> m_subgraphNodes; + std::vector m_subgraphInputs; + std::vector m_subgraphOutputs; + mutable std::vector> m_intermediateNodeArgs; + std::unordered_map m_partitionNodePropsMap; + std::vector m_ownedInitializers; + mutable std::unordered_map> m_isInitializerTransferable; + std::vector m_subgraphNodePointers; + + // Bindings from previous executions of a re-used command list + mutable std::vector> m_ownedCpuInputs; + mutable ComPtr m_compiledExecutionPlanOperator; + mutable std::vector m_inputsUsed; + mutable ComPtr m_persistentResource; + mutable ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator + mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes; + mutable std::unordered_map m_inferredInputShapes; + }; + + onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& info, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers) + { + return new DmlRuntimeFusedGraphKernel( + info, + std::move(indexedSubGraph), + modelPath, + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers) + ); + } +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h new file mode 100644 index 0000000000..d679c5aa56 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/op_kernel.h" +#include "GraphDescBuilder.h" +#include "DmlRuntimeGraphFusionTransformer.h" + +namespace Dml +{ + onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& info, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers + ); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp new file mode 100644 index 0000000000..6318b0d5e2 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -0,0 +1,161 @@ +#pragma once + +#include "precomp.h" +#include "GraphDescBuilder.h" +#include "ExecutionProvider.h" +#include "DmlRuntimeGraphFusionTransformer.h" +#include "GraphPartitioner.h" +#include "core/framework/kernel_type_str_resolver.h" +#include "core/framework/kernel_lookup.h" +#include "core/optimizer/constant_sharing.h" +#include "DmlRuntimeFusedGraphKernel.h" +#include "MLOperatorAuthorImpl.h" +#include "DmlGraphFusionHelper.h" + +namespace Dml +{ + namespace + { + struct CompiledPartitionInfo + { + std::shared_ptr indexedSubGraph; + std::unordered_map> isInitializerTransferable; + }; + } + + DmlRuntimeGraphFusionTransformer::DmlRuntimeGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ) + :onnxruntime::GraphTransformer(name), + m_providerImpl(static_cast(provider)->GetImpl()) + { + } + + onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImpl( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger) const + { + return ApplyImplHelper(graph, modified, graphLevel, logger, {}); + } + + onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const + { + onnxruntime::ProviderType providerType = onnxruntime::kDmlExecutionProvider; + const gsl::not_null registry = m_providerImpl->GetKernelRegistry().get(); + const auto kernelTypeStrResolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; + const auto kernelLookup = onnxruntime::KernelLookup( + providerType, + gsl::make_span(®istry, 1), + kernelTypeStrResolver); + + onnxruntime::GraphViewer graphViewer(graph); + const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder(); + + for (auto nodeIndex : nodeTopologyList) + { + auto* node = graph.GetNode(nodeIndex); + if (!node) + { + continue; // node was removed + } + + std::unordered_map subgraphImplicitInputDefs; + for (const onnxruntime::NodeArg* inputDef : node->ImplicitInputDefs()) + { + subgraphImplicitInputDefs[inputDef->Name()] = inputDef; + } + + for (auto& entry : node->GetAttributeNameToMutableSubgraphMap()) + { + auto& subgraph = *entry.second; + ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graphLevel + 1, logger, subgraphImplicitInputDefs)); + } + } + + // Initializers needed by any graph partition + std::vector additionalSplittingNodes; + std::unordered_map graphNodePropertyMap; + std::unordered_set requiredInitializerMap; + std::unordered_set dynamicCpuInputMap; + std::vector> partitions = BuildPartitions( + graphViewer, + *m_providerImpl->GetInternalRegistrationInfoMap(), + kernelLookup, + m_providerImpl->GetSupportedDeviceDataTypeMask(), + graphNodePropertyMap, + requiredInitializerMap, + dynamicCpuInputMap, + additionalSplittingNodes, + implicitInputDefs, + true); + + // Reset the splitting nodes for the current iteration + additionalSplittingNodes.clear(); + + // Reset the compiled operators for the current iteration + std::vector> compiledPartitionInfos(partitions.size()); + + // Create a map between each initialized tensor and the partition(s) it is part of. + auto initializerPartitionMap = DmlGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); + + for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex) + { + auto& partition = partitions[partitionIndex]; + + if (partition->GetRootMergedPartition() != partition.get() || + !partition->IsDmlPartition()) + { + continue; + } + + if (partition->IsDmlGraphPartition()) + { + std::unordered_map> isInitializerTransferable; + + std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_"; + m_providerImpl->IncreasePartitionKernelPrefixVal(); + + // populate isInitializerTransferable + for (const auto& input : partition->GetInputs()) + { + const onnx::TensorProto* tensor = nullptr; + if (graph.GetInitializedTensor(input, tensor) && requiredInitializerMap.find(input) != requiredInitializerMap.end()) + { + isInitializerTransferable[input] = {tensor, false}; + } + } + + compiledPartitionInfos[partitionIndex] = std::make_shared(); + compiledPartitionInfos[partitionIndex]->indexedSubGraph = std::make_shared( + DmlGraphFusionHelper::CreateIndexedSubGraph(partition.get(), partitionIndex, partitionKernelPrefix)); + compiledPartitionInfos[partitionIndex]->isInitializerTransferable = std::move(isInitializerTransferable); + } + } + + for (auto&& compiledPartitionInfo : compiledPartitionInfos) + { + // Null compiled operators were not DML partitions + if (compiledPartitionInfo) + { + DmlGraphFusionHelper::RegisterDynamicKernel( + graph, + m_providerImpl->GetKernelRegistry().get(), + m_providerImpl, + graphNodePropertyMap, + dynamicCpuInputMap, + std::move(compiledPartitionInfo->indexedSubGraph), + std::move(compiledPartitionInfo->isInitializerTransferable)); + } + } + + return onnxruntime::common::Status::OK(); + } +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h new file mode 100644 index 0000000000..cfa743e1f2 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include "core/optimizer/graph_transformer.h" +#include "core/framework/execution_providers.h" + +namespace Dml +{ +class ExecutionProviderImpl; + +class DmlRuntimeGraphFusionTransformer : public onnxruntime::GraphTransformer +{ +public: + DmlRuntimeGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ); + +public: + static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlRuntimeFusedNode_"; + static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlRuntimeFusedNodeDomain"; + +private: + onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger) const final; + + onnxruntime::common::Status ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const; + +private: + const ExecutionProviderImpl* m_providerImpl = nullptr; +}; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index f97b72aa2d..277da1591b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -67,7 +67,8 @@ namespace Dml ExecutionProvider::ExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands) : + bool enableMetacommands, + bool enableDynamicGraphFusion) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)) { D3D12_COMMAND_LIST_TYPE queueType = commandQueue->GetDesc().Type; @@ -80,7 +81,7 @@ namespace Dml ComPtr device; GRAPHICS_THROW_IF_FAILED(commandQueue->GetDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf()))); - m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), commandQueue, enableMetacommands); + m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), commandQueue, enableMetacommands, enableDynamicGraphFusion); } std::vector> @@ -147,12 +148,12 @@ namespace Dml // Task 24384515: Update ORT AIInfra release agent pool to install 19H1 SDK on VM bootstrap #define D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE ((D3D_FEATURE_LEVEL)0x1000) - ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands) + ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands, bool enableDynamicGraphFusion) : m_d3d12Device(d3d12Device), m_dmlDevice(dmlDevice), - m_areMetacommandsEnabled(enableMetacommands) + m_areMetacommandsEnabled(enableMetacommands), + m_dynamicGraphFusionEnabled(enableDynamicGraphFusion) { - D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; D3D_FEATURE_LEVEL featureLevelsList[] = { @@ -1093,6 +1094,11 @@ namespace Dml return m_areMetacommandsEnabled; } + bool ExecutionProviderImpl::DynamicGraphFusionEnabled() const noexcept + { + return m_dynamicGraphFusionEnabled; + } + std::shared_ptr ExecutionProviderImpl::GetInternalRegistrationInfoMap() const { @@ -1129,9 +1135,10 @@ namespace Dml std::unique_ptr CreateExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands) + bool enableMetacommands, + bool enableDynamicGraphFusion) { - return std::make_unique(dmlDevice, commandQueue, enableMetacommands); + return std::make_unique(dmlDevice, commandQueue, enableMetacommands, enableDynamicGraphFusion); } ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 31b893a2f2..3aaa11cdee 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -5,6 +5,7 @@ #include "GraphTransformer.h" #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" +#include "core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h" #include #include @@ -34,7 +35,8 @@ namespace Dml IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, - bool enableMetacommands = true); + bool enableMetacommands, + bool enableDynamicGraphFusion); void ReleaseCompletedReferences(); @@ -150,6 +152,7 @@ namespace Dml STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; + bool DynamicGraphFusionEnabled() const noexcept; std::shared_ptr GetGpuAllocator(); std::shared_ptr GetCpuInputAllocator(); @@ -184,6 +187,7 @@ namespace Dml ComPtr m_dmlDevice; bool m_isMcdmDevice = false; bool m_areMetacommandsEnabled = true; + bool m_dynamicGraphFusionEnabled = false; bool m_native16BitShaderOpsSupported = false; std::shared_ptr m_context; std::unique_ptr m_uploadHeap; @@ -236,7 +240,8 @@ namespace Dml explicit ExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands = true + bool enableMetacommands, + bool enableDynamicGraphFusion ); std::unique_ptr GetDataTransfer() const final override @@ -299,9 +304,9 @@ namespace Dml return m_impl.Get(); } - void MetacommandsEnabled() + bool DynamicGraphFusionEnabled() const { - m_impl->MetacommandsEnabled(); + return m_impl->DynamicGraphFusionEnabled(); } virtual std::vector CreatePreferredAllocators() override diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 636f46428c..3fc8f415e5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -147,14 +147,14 @@ namespace Dml::GraphDescBuilder const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle) + const void* executionHandle, + const onnxruntime::Path& modelPath, + gsl::span subgraphNodes, + gsl::span subgraphInputs, + gsl::span subgraphOutputs) { - const gsl::span subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs; - const gsl::span subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs; struct NodeAndIndex { uint32_t nodeIndex; // The index of the node itself @@ -164,12 +164,14 @@ namespace Dml::GraphDescBuilder // Map from Lotus node argument names to the new node and index where it will be produced std::unordered_map nameToNodeAndIndexMap; + std::unordered_map nodeOutputShapes; + // Map from Lotus node argument names to input indices of the fused kernel node. std::unordered_map nameToDmlFusedNodeInputIndex; - for (size_t inputIndex = 0; inputIndex < subGraphInputArgNames.size(); ++inputIndex) + for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex) { - const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(subGraphInputArgNames[inputIndex]); + const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex]; if (!graphInput) { @@ -196,13 +198,11 @@ namespace Dml::GraphDescBuilder const uint32_t minNodeCountToReuseCommandList = 5; bool reuseCommandList = false; - if (indexedSubGraph.nodes.size() >= minNodeCountToReuseCommandList) + if (subgraphNodes.size() >= minNodeCountToReuseCommandList) { reuseCommandList = true; } - auto modelPath = graph.ModelPath(); - auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName) { ComPtr tensorWrapper; @@ -219,9 +219,11 @@ namespace Dml::GraphDescBuilder // Iterate through each node and create a corresponding node in the new graph // We can iterate the nodes in any order because the edge connectivity will take care of the topological order - for (size_t sortedNodeIndex : indexedSubGraph.nodes) + std::unordered_map> inferredOutputShapes; + + for (const onnxruntime::Node* subgraphNode : subgraphNodes) { - const onnxruntime::Node& node = *graph.GetNode(sortedNodeIndex); + const onnxruntime::Node& node = *subgraphNode; const GraphNodeProperties& graphNodeProps = graphNodePropertyMap.find(GetUniqueNodeName(node))->second; const auto& requiredConstantCpuInputs = graphNodeProps.internalRegInfo->requiredConstantCpuInputs; @@ -244,14 +246,45 @@ namespace Dml::GraphDescBuilder return tensor; }; + EdgeShapes inputShapesOverrides(node.InputDefs().size()); + + // Override the input shapes with shapes that were previously inferred + for (int inputIndex = 0; inputIndex < node.InputDefs().size(); ++inputIndex) + { + auto inputDef = node.InputDefs()[inputIndex]; + + auto outputShapesIter = inferredOutputShapes.find(inputDef->Name()); + if (outputShapesIter != inferredOutputShapes.end()) + { + inputShapesOverrides.GetMutableShape(inputIndex) = outputShapesIter->second; + } + else if (inputDef->HasTensorOrScalarShape()) + { + for (int i = 0; i < inputDef->Shape()->dim_size(); ++i) + { + ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->Shape()->dim(i).has_dim_value()); + inputShapesOverrides.GetMutableShape(inputIndex).push_back(gsl::narrow_cast(inputDef->Shape()->dim(i).dim_value())); + } + } + } + + EdgeShapes outputShapes; DmlGraphNodeCreateInfo graphNodeCreateInfo; graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, + &inputShapesOverrides, + /*out*/ &outputShapes, /*out*/ &graphNodeCreateInfo ); + ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size()); + for (int i = 0; i < node.OutputDefs().size(); ++i) + { + inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i); + } + // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex. std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); @@ -347,6 +380,8 @@ namespace Dml::GraphDescBuilder operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex], operatorGraphOutputEdge.FromNodeOutputIndex }; + + nodeOutputShapes[arg->Name()] = outputShapes; } } @@ -367,10 +402,12 @@ namespace Dml::GraphDescBuilder } } + EdgeShapes graphOutputShapes(subgraphOutputs.size()); + // Add graph output nodes, which might be in a different order from the encapsulating node - for (size_t outputIndex = 0; outputIndex < subGraphOutputArgNames.size(); ++outputIndex) + for (size_t outputIndex = 0; outputIndex < subgraphOutputs.size(); ++outputIndex) { - const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(subGraphOutputArgNames[outputIndex]); + const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex]; ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); @@ -380,6 +417,7 @@ namespace Dml::GraphDescBuilder edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex; edge.GraphOutputIndex = gsl::narrow_cast(outputIndex); graphOutputEdges.push_back(edge); + graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex); } RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges); @@ -390,6 +428,7 @@ namespace Dml::GraphDescBuilder graphDesc.outputEdges = std::move(graphOutputEdges); graphDesc.intermediateEdges = std::move(graphIntermediateEdges); graphDesc.reuseCommandList = reuseCommandList; + graphDesc.outputShapes = std::move(graphOutputShapes); return graphDesc; } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 5c04962e55..0039678c00 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -9,10 +9,10 @@ namespace Dml { struct GraphNodeProperties { - std::shared_ptr + std::shared_ptr internalRegInfo; - // These are currently passed from the partitioning step since the only DML operators current + // These are currently passed from the partitioning step since the only DML operators current // supporting graph nodes don't customize the order of edges or shapes, other than coercing // dimension count. This will change as the supported set of operators as graph nodes increases. Windows::AI::MachineLearning::Adapter::EdgeShapes inputShapes; @@ -38,16 +38,19 @@ namespace Dml std::vector outputEdges; std::vector intermediateEdges; bool reuseCommandList; + Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes; }; GraphDesc BuildGraphDesc( const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle); + const void* executionHandle, + const onnxruntime::Path& modelPath, + gsl::span subgraphNodes, + gsl::span subgraphInputs, + gsl::span subgraphOutputs); } -} \ No newline at end of file +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 2c8d4e4459..f7a4743801 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -151,6 +151,8 @@ namespace Dml _In_opt_ const std::unordered_map* nodeNameToPartitionMap, _Inout_ std::unordered_map& dmlNodePropertyMap, _Inout_ std::unordered_set& requiredInitializerMap, + _Inout_ std::unordered_set& dynamicCpuInputMap, + bool allowDmlGraphDynamicShapes, _Out_ bool* isDmlGraphNode ) { @@ -172,36 +174,68 @@ namespace Dml if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration) { - bool requiredCpuInputsConstant = true; - for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + if (allowDmlGraphDynamicShapes) { - if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) { - continue; - } + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + { + continue; + } - const onnx::TensorProto* tensor = nullptr; - const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); - if (!graph.GetInitializedTensor(inputName, tensor)) - { - requiredCpuInputsConstant = false; - break; + if (graph.GetInitializedTensor(inputName, tensor)) + { + requiredInitializerMap.insert(inputName); + } + else + { + dynamicCpuInputMap.insert(inputName); + } } - requiredInitializerMap.insert(inputName); + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size()) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } } - - std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; - if (requiredCpuInputsConstant && - TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && - TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && - (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + else { - *isDmlGraphNode = true; - graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + bool requiredCpuInputsConstant = true; + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + { + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + { + continue; + } + + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + + if (!graph.GetInitializedTensor(inputName, tensor)) + { + requiredCpuInputsConstant = false; + break; + } + + requiredInitializerMap.insert(inputName); + } + + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredCpuInputsConstant && + TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && + TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && + (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } } } } @@ -345,13 +379,8 @@ namespace Dml // Whether any operator in the model contains a subgraph. This is true // if the graph being partitioned is itself within a subgraph, or contains // an operator with a subgraph. - bool ModelUsesSubgraph(const onnxruntime::GraphViewer& graph) + bool ContainsSubgraph(const onnxruntime::GraphViewer& graph) { - if (graph.IsSubgraph()) - { - return true; - } - const std::vector& toplogicalOrder = graph.GetNodesInTopologicalOrder(); for (size_t nodeIndex : toplogicalOrder) @@ -384,7 +413,10 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, - gsl::span additionalSplittingNodes) + std::unordered_set& dynamicCpuInputMap, + gsl::span additionalSplittingNodes, + const std::unordered_map& implicitInputs, + bool allowDmlGraphDynamicShapes) { // Nodes are uniquely identified by the name of their first output argument std::vector> partitions; @@ -419,7 +451,7 @@ namespace Dml } // Check whether this graph is a subgraph, or contains any node with a subgraph. - bool modelUsesSubgraph = ModelUsesSubgraph(graph); + bool containsSubgraph = ContainsSubgraph(graph); uint32_t splittingNodeIndex = 0; @@ -447,6 +479,8 @@ namespace Dml &nodeNameToPartitionMap, graphNodePropertyMap, requiredInitializerMap, + dynamicCpuInputMap, + allowDmlGraphDynamicShapes, /*out*/ &isDmlGraphNode ); } @@ -454,10 +488,10 @@ namespace Dml // Add a unique partition if graph node usage is not supported. // // Partitioning is disabled in models with subgraphs to work around issues with implicit inputs. - // The partitioning algorithm does not currently consider such inputs. Transfering shared initializers + // The partitioning algorithm does not currently consider such inputs. Transferring shared initializers // for partitions could also cause problems. Note, operators with subgraphs are currently not efficient // anyhow due to CPU/GPU copies. - if (modelUsesSubgraph || !isDmlGraphNode) + if (containsSubgraph || !isDmlGraphNode) { partitions.push_back(CreatePartitionAndFinalizeInputs(node, isDmlNode, false, nodeNameToPartitionMap)); continue; @@ -505,7 +539,7 @@ namespace Dml firstNonFinalInputPartition->AddInput(arg->Name()); } - if (graphInputs.find(arg->Name()) != graphInputs.end()) + if (graphInputs.find(arg->Name()) != graphInputs.end() || implicitInputs.find(arg->Name()) != implicitInputs.end()) { firstNonFinalInputPartition->AddInput(arg->Name()); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index 990ba00fc4..3bddb5ae16 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -3,6 +3,8 @@ #pragma once +#include +#include #include "core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h" namespace Dml @@ -48,5 +50,8 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, - gsl::span additionalSplittingNodes); + std::unordered_set& dynamicCpuInputMap, + gsl::span additionalSplittingNodes, + const std::unordered_map& implicitInputs, + bool allowDmlGraphDynamicShapes); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h index d7a0a607cd..a8a6d6745e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h @@ -2,8 +2,15 @@ // Licensed under the MIT License. #pragma once + +#include + #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +interface IDMLCompiledOperator; +struct DML_BUFFER_BINDING; +struct DML_BINDING_DESC; + namespace Dml { struct Binding diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 6cd10e14e0..4deec620fe 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1356,13 +1356,14 @@ namespace Windows::AI::MachineLearning::Adapter const onnxruntime::OpNodeProtoHelper* protoHelper, const void* executionHandle, bool isInternalOperator, + const EdgeShapes* inputShapesOverrides, const EdgeShapes* inferredOutputShapes, const AttributeMap* defaultAttributes, DmlGraphNodeCreateInfo* graphNodeCreateInfo, gsl::span requiredConstantCpuInputs, MLOperatorTensorGetter& constantInputGetter ) - : OpNodeInfoWrapper(protoHelper, nullptr, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr), + : OpNodeInfoWrapper(protoHelper, inputShapesOverrides, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr), m_inferredOutputShapes(inferredOutputShapes), m_internalOperator(isInternalOperator), m_graphNodeCreateInfo(graphNodeCreateInfo) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index a7f8bebb2d..913997ff4a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -4,6 +4,7 @@ #pragma once #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" #include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" #include "core/framework/op_kernel.h" #include "core/framework/customregistry.h" #include "core/framework/tensorprotoutils.h" @@ -93,42 +94,6 @@ struct AttributeValue using AttributeMap = std::map; -// Encapsulation of shapes across different edges of an operator. Non-tensor -// edges and unused edges have an empty array of dimensions. -class EdgeShapes -{ -public: - EdgeShapes() = default; - - EdgeShapes(size_t count) : m_shapes(count) {} - - const std::vector& GetShape(size_t edgeIndex) const - { - return m_shapes[edgeIndex]; - } - - std::vector& GetMutableShape(size_t edgeIndex) - { - return m_shapes[edgeIndex]; - } - - size_t EdgeCount() const { return m_shapes.size(); } - - void Reset(size_t edge_count) - { - m_shapes.clear(); - m_shapes.resize(edge_count); - } - - bool operator!=(const EdgeShapes& other) const noexcept - { - return (m_shapes != other.m_shapes); - } - - private: - std::vector> m_shapes; -}; - // Base class for ABI objects which may be "Closed", at which point calls will predictably // fail or return a dummy value. This is used for transient ABI context objects which // are passed to methods on kernel or inferencers, and which wrap Lotus objects whose lifetimes @@ -434,6 +399,7 @@ class DmlGraphOpKernelInfoWrapper : public OpNodeInfoWrapper< const onnxruntime::OpNodeProtoHelper * protoHelper, const void* executionHandle, bool isInternalOperator, + const EdgeShapes* inputShapesOverrides, const EdgeShapes* inferredOutputShapes, const AttributeMap* defaultAttributes, DmlGraphNodeCreateInfo* graphNodeCreateInfo, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp index 9c1a7baeaa..03500d0ee8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp @@ -205,12 +205,34 @@ class DmlOperatorMultiHeadAttention : public DmlOperator else { const auto keyPaddingMaskTensorShape = m_inputTensorDescs[dmlMaskIndex].GetSizes(); - ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape.size() == 2); + size_t maskDimCount = keyPaddingMaskTensorShape.size(); + ML_CHECK_VALID_ARGUMENT(maskDimCount >= 2 || maskDimCount <= 4); ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[0] == batchSize); - ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength); - const uint32_t actualShape[4] = {batchSize, 1, 1, kvSequenceLength}; - const uint32_t desiredShape[4] = {batchSize, numHeads, sequenceLength, kvSequenceLength}; + std::array actualShape{}; + std::array desiredShape{}; + + if (maskDimCount == 2) + { + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength); + actualShape = {batchSize, 1, 1, kvSequenceLength}; + desiredShape = {batchSize, numHeads, sequenceLength, kvSequenceLength}; + } + else if (maskDimCount == 3) + { + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == totalSequenceLength); + actualShape = {batchSize, 1, sequenceLength, totalSequenceLength}; + desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength}; + } + else if (maskDimCount == 4) + { + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == numHeads); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[3] == totalSequenceLength); + actualShape = {batchSize, numHeads, sequenceLength, totalSequenceLength}; + desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength}; + } m_inputTensorDescs[dmlMaskIndex] = TensorDesc::ConstructBroadcastedTensorDesc( m_inputTensorDescs[dmlMaskIndex].GetMlOperatorDataType(), diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp new file mode 100644 index 0000000000..30c339b845 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -0,0 +1,436 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +// This operator is easier to understand by looking at a python implementation of the non-interleaved version: +// +// def rotate_half(x): +// """Rotates half the hidden dims of the input.""" +// half_dim = x.shape[-1] // 2 +// x1 = x[..., :half_dim] +// x2 = x[..., half_dim:] +// return np.concatenate((-x2, x1), dim=-1) +// +// +// def apply_rope(x, cos, sin, position_ids): +// cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] +// sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] +// x_embed = (x * cos) + (rotate_half(x) * sin) +// return x_embed +// +// For the non-interleaved version, we multiply the cos cache by the non-rotated input tensor while we multiply the sin cache +// by the rotated input tensor. Rotating the tensor means slicing it in half on the head dimension and swapping the 2 halves. +// +// The interleaved version is very similar but instead of swapping 2 halves, we swap every pair of adjacent elements and we swap +// the sign of every adjacent element. + +namespace Dml +{ +class DmlOperatorRotaryEmbedding : public DmlOperator +{ +public: + DmlOperatorRotaryEmbedding(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo) + { + enum InputIndex : uint32_t + { + inputDataIndex, + positionIdsIndex, + cosCacheIndex, + sinCacheIndex, + }; + + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4); + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + + // When positionIds is a scalar, it represents the start offset for each sequence + const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1; + + Initialize(kernelInfo); + + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[positionIdsIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[sinCacheIndex].GetDimensionCount() == 4); + + ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs[0].GetDimensionCount() == 4); + + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[cosCacheIndex].GetSizes() == m_inputTensorDescs[sinCacheIndex].GetSizes()); + const uint32_t headSize = m_inputTensorDescs[cosCacheIndex].GetSizes().back() * 2; + + // The last dimension of the data is the hidden size, so it must be divisible by the head size + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[inputDataIndex].GetSizes().back() % headSize == 0); + + // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] + const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes(); + const uint32_t batchSize = inputDataSizes[1]; + const uint32_t sequenceLength = inputDataSizes[2]; + const uint32_t numHeads = inputDataSizes[3] / headSize; + + const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes(); + const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2]; + + if (sequenceLength > maxSequenceLength) + { + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + const bool interleaved = gsl::narrow_cast(kernelInfo.GetOptionalAttribute(AttrName::Interleaved, 0)); + + std::vector inputDescs = GetDmlInputDescs(); + const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; + + // Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle + const std::array inputOutputShape = {batchSize, sequenceLength, numHeads, headSize}; + TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); + + // Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. + DML_SCALE_BIAS scaleBias = {1.0f, 0.0f}; + + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{}; + copyInputDesc.InputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.ScaleBias = &scaleBias; + const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, ©InputDesc}; + + // Split the input data into 2 equal parts + const std::vector inputDataTensorShape = interleaved + ? std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 2}) + : std::vector({batchSize, sequenceLength, numHeads, 2, headSize / 2}); + + const std::vector splitInputDataTensorShape = interleaved + ? std::vector({batchSize, sequenceLength, numHeads, headSize / 2, 1}) + : std::vector({batchSize, sequenceLength, numHeads, 1, headSize / 2}); + + TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); + + TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); + const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; + + DML_SPLIT_OPERATOR_DESC splitInputDesc{}; + splitInputDesc.InputTensor = &inputDataDmlTensorDesc; + splitInputDesc.OutputTensors = splitInputDataDmlTensorDescs.data(); + splitInputDesc.OutputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + splitInputDesc.Axis = interleaved + ? gsl::narrow_cast(splitInputDataTensorShape.size()) - 1 + : gsl::narrow_cast(splitInputDataTensorShape.size()) - 2; + + const DML_OPERATOR_DESC splitInputDmlDesc = {DML_OPERATOR_SPLIT, &splitInputDesc}; + + // Swap the 2 halves and join them together + DML_JOIN_OPERATOR_DESC joinInputDesc{}; + joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); + joinInputDesc.OutputTensor = &inputDataDmlTensorDesc; + joinInputDesc.Axis = splitInputDesc.Axis; + joinInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); + const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc}; + + // We generate a sequence from 0 to sequenceLength and add the offset to it + const std::array positionIdsRangeShape = {1, 1, 1, sequenceLength}; + auto positionIdsDataType = kernelInfo.GetInputEdgeDescription(positionIdsIndex).tensorDataType; + TensorDesc positionIdsRangeTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, positionIdsRangeShape); + const DML_TENSOR_DESC positionIdsRangeDmlTensorDesc = positionIdsRangeTensorDesc.GetDmlDesc(); + + const std::array broadcastedPositionIdsRangeShape = {1, 1, batchSize, sequenceLength}; + TensorDesc broadcastedPositionIdsRangeTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedPositionIdsRangeShape, positionIdsRangeShape); + const DML_TENSOR_DESC broadcastedPositionIdsRangeDmlTensorDesc = broadcastedPositionIdsRangeTensorDesc.GetDmlDesc(); + + const std::array broadcastedOffsetShape = {1, 1, batchSize, sequenceLength}; + TensorDesc broadcastedOffsetTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(positionIdsDataType, broadcastedOffsetShape, m_inputTensorDescs[positionIdsIndex].GetSizes()); + const DML_TENSOR_DESC broadcastedOffsetDmlTensorDesc = broadcastedOffsetTensorDesc.GetDmlDesc(); + + TensorDesc offsetPositionIdsTensorDesc = TensorDesc::ConstructDefaultTensorDesc(positionIdsDataType, broadcastedOffsetShape); + const DML_TENSOR_DESC offsetPositionIdsRangeDmlTensorDesc = offsetPositionIdsTensorDesc.GetDmlDesc(); + + DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC positionIdsRange{}; + DML_ELEMENT_WISE_ADD_OPERATOR_DESC positionIdsAddOffset{}; + if (positionIdsIsOffset) + { + ML_CHECK_VALID_ARGUMENT(positionIdsDataType == MLOperatorTensorDataType::Int64); + positionIdsRange.ValueDataType = DML_TENSOR_DATA_TYPE_INT64; + positionIdsRange.ValueDelta.Int64 = 1; + positionIdsRange.OutputTensor = &positionIdsRangeDmlTensorDesc; + + positionIdsAddOffset.ATensor = &broadcastedPositionIdsRangeDmlTensorDesc; + positionIdsAddOffset.BTensor = &broadcastedOffsetDmlTensorDesc; + positionIdsAddOffset.OutputTensor = &offsetPositionIdsRangeDmlTensorDesc; + } + const DML_OPERATOR_DESC positionIdsRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &positionIdsRange}; + const DML_OPERATOR_DESC positionIdsAddOffsetDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &positionIdsAddOffset}; + + // Gather the cos/sin values based on the position ids + const std::array gatheredCosSinShape = {1, batchSize, sequenceLength, headSize / 2}; + TensorDesc gatheredCosSinTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, gatheredCosSinShape); + const DML_TENSOR_DESC gatheredCosSinDmlTensorDesc = gatheredCosSinTensorDesc.GetDmlDesc(); + + DML_GATHER_OPERATOR_DESC gatherCosSinDesc{}; + gatherCosSinDesc.InputTensor = &inputDescs[cosCacheIndex]; + gatherCosSinDesc.IndicesTensor = positionIdsIsOffset ? &offsetPositionIdsRangeDmlTensorDesc : &inputDescs[positionIdsIndex]; + gatherCosSinDesc.OutputTensor = &gatheredCosSinDmlTensorDesc; + gatherCosSinDesc.Axis = 2; + gatherCosSinDesc.IndexDimensions = 2; + const DML_OPERATOR_DESC gatherCosSinDmlDesc {DML_OPERATOR_GATHER, &gatherCosSinDesc}; + + // After gathering cos/sin, reshape and broadcast them to match the number of heads of the input data + const std::vector reshapedCosSinShape = interleaved + ? std::vector({batchSize, sequenceLength, 1, headSize / 2, 1}) + : std::vector({batchSize, sequenceLength, 1, 1, headSize / 2}); + TensorDesc broadcastedCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedCosSinShape); + const DML_TENSOR_DESC broadcastedCosSinDmlTensorDesc = broadcastedCosSinTensorDesc.GetDmlDesc(); + + // Create a vector that contains the sign values {-1, 1} + const std::array signTensorShape = {2}; + TensorDesc signTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, signTensorShape); + const DML_TENSOR_DESC signDmlTensorDesc = signTensorDesc.GetDmlDesc(); + + DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC signRange{}; + signRange.OutputTensor = &signDmlTensorDesc; + if (dataType == MLOperatorTensorDataType::Float16) + { + const auto valueStart = static_cast(-1.0f); + const auto valueDelta = static_cast(2.0f); + memcpy(signRange.ValueStart.Bytes, reinterpret_cast(&valueStart), sizeof(valueStart)); + memcpy(signRange.ValueDelta.Bytes, reinterpret_cast(&valueDelta), sizeof(valueDelta)); + signRange.ValueDataType = DML_TENSOR_DATA_TYPE_FLOAT16; + } + else + { + ML_CHECK_VALID_ARGUMENT(dataType == MLOperatorTensorDataType::Float); + signRange.ValueStart.Float32 = -1.0f; + signRange.ValueDelta.Float32 = 2.0f; + signRange.ValueDataType = DML_TENSOR_DATA_TYPE_FLOAT32; + } + const DML_OPERATOR_DESC signRangeDmlDesc = {DML_OPERATOR_FILL_VALUE_SEQUENCE, &signRange}; + + // Multiply the broadcasted sign values with the rotated input + const std::vector reshapedSignShape = interleaved + ? std::vector({1, 1, 1, 1, 2}) + : std::vector({1, 1, 1, 2, 1}); + TensorDesc broadcastedSignCosSinTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(dataType, inputDataTensorShape, reshapedSignShape); + const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc(); + + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{}; + mulSignDesc.ATensor = &inputDataDmlTensorDesc; + mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc; + mulSignDesc.OutputTensor = &inputDataDmlTensorDesc; + const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc}; + + // Multiply the non-rotated data with the cos and the rotated data with the sin + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{}; + mulCosSinDesc.ATensor = &inputDataDmlTensorDesc; + mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc; + mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc; + const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc}; + + // Add the multiplied cos and sin values together + DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; + addDesc.ATensor = &inputOutputDmlTensorDesc; + addDesc.BTensor = &inputOutputDmlTensorDesc; + addDesc.OutputTensor = &inputOutputDmlTensorDesc; + const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; + + // Construct the graph + std::vector inputEdges; + std::vector intermediateEdges; + std::vector outputEdges; + + std::vector opDescs = { + ©InputDmlDesc, // Copy the input data to preseve the real input shape + &splitInputDmlDesc, // Split the input data + &gatherCosSinDmlDesc, // Gather cos + &gatherCosSinDmlDesc, // Gather sin + &signRangeDmlDesc, // Generate the signs + + &joinInputDmlDesc, // Join the split data + &mulCosSinDmlDesc, // Multiply cos with the non-rotated data + &mulCosSinDmlDesc, // Multiply sin with the rotated data + &mulSignDmlDesc, // Multiply the sign with the rotated data + &addDmlDesc, // Add the rotated cos and non-rotated sin parts together + }; + + enum NodeIndex : uint32_t + { + copyInputOpIndex, + splitInputOpIndex, + gatherCosOpIndex, + gatherSinOpIndex, + signRangeOpIndex, + + joinInputOpIndex, + mulCosOpIndex, + mulSinOpIndex, + mulSignOpIndex, + addOpIndex, + + // The following indices are optional + positionIdsRangeOpIndex, + positionIdsAddOffsetOpIndex, + }; + + if (positionIdsIsOffset) + { + opDescs.push_back(&positionIdsRangeDmlDesc); + opDescs.push_back(&positionIdsAddOffsetDmlDesc); + + DML_INPUT_GRAPH_EDGE_DESC positionIdsToAddOffsetEdge = {}; + positionIdsToAddOffsetEdge.GraphInputIndex = positionIdsIndex; + positionIdsToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsToAddOffsetEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToAddOffsetEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsOffsetToAddOffsetEdge = {}; + positionIdsOffsetToAddOffsetEdge.FromNodeIndex = positionIdsRangeOpIndex; + positionIdsOffsetToAddOffsetEdge.FromNodeOutputIndex = 0; + positionIdsOffsetToAddOffsetEdge.ToNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsOffsetToAddOffsetEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(positionIdsOffsetToAddOffsetEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherCosEdge = {}; + positionIdsAddOffsetToGatherCosEdge.FromNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsAddOffsetToGatherCosEdge.FromNodeOutputIndex = 0; + positionIdsAddOffsetToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; + positionIdsAddOffsetToGatherCosEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(positionIdsAddOffsetToGatherCosEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC positionIdsAddOffsetToGatherSinEdge = {}; + positionIdsAddOffsetToGatherSinEdge.FromNodeIndex = positionIdsAddOffsetOpIndex; + positionIdsAddOffsetToGatherSinEdge.FromNodeOutputIndex = 0; + positionIdsAddOffsetToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; + positionIdsAddOffsetToGatherSinEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(positionIdsAddOffsetToGatherSinEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherCosEdge = {}; + positionIdsToGatherCosEdge.GraphInputIndex = positionIdsIndex; + positionIdsToGatherCosEdge.ToNodeIndex = gatherCosOpIndex; + positionIdsToGatherCosEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToGatherCosEdge); + + DML_INPUT_GRAPH_EDGE_DESC positionIdsToGatherSinEdge = {}; + positionIdsToGatherSinEdge.GraphInputIndex = positionIdsIndex; + positionIdsToGatherSinEdge.ToNodeIndex = gatherSinOpIndex; + positionIdsToGatherSinEdge.ToNodeInputIndex = 1; + inputEdges.push_back(positionIdsToGatherSinEdge); + } + + DML_INPUT_GRAPH_EDGE_DESC inputToCopyInputEdge = {}; + inputToCopyInputEdge.GraphInputIndex = inputDataIndex; + inputToCopyInputEdge.ToNodeIndex = copyInputOpIndex; + inputToCopyInputEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToCopyInputEdge); + + DML_INPUT_GRAPH_EDGE_DESC cosToGatherEdge = {}; + cosToGatherEdge.GraphInputIndex = cosCacheIndex; + cosToGatherEdge.ToNodeIndex = gatherCosOpIndex; + cosToGatherEdge.ToNodeInputIndex = 0; + inputEdges.push_back(cosToGatherEdge); + + DML_INPUT_GRAPH_EDGE_DESC sinToGatherEdge = {}; + sinToGatherEdge.GraphInputIndex = sinCacheIndex; + sinToGatherEdge.ToNodeIndex = gatherSinOpIndex; + sinToGatherEdge.ToNodeInputIndex = 0; + inputEdges.push_back(sinToGatherEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC inputToSplitEdge = {}; + inputToSplitEdge.FromNodeIndex = copyInputOpIndex; + inputToSplitEdge.FromNodeOutputIndex = 0; + inputToSplitEdge.ToNodeIndex = splitInputOpIndex; + inputToSplitEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(inputToSplitEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC nonRotatedDataToMulEdge = {}; + nonRotatedDataToMulEdge.FromNodeIndex = copyInputOpIndex; + nonRotatedDataToMulEdge.FromNodeOutputIndex = 0; + nonRotatedDataToMulEdge.ToNodeIndex = mulCosOpIndex; + nonRotatedDataToMulEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(nonRotatedDataToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC secondHalfDataToJoinEdge = {}; + secondHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + secondHalfDataToJoinEdge.FromNodeOutputIndex = 1; + secondHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + secondHalfDataToJoinEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(secondHalfDataToJoinEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC firstHalfDataToJoinEdge = {}; + firstHalfDataToJoinEdge.FromNodeIndex = splitInputOpIndex; + firstHalfDataToJoinEdge.FromNodeOutputIndex = 0; + firstHalfDataToJoinEdge.ToNodeIndex = joinInputOpIndex; + firstHalfDataToJoinEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(firstHalfDataToJoinEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC cosToMulEdge = {}; + cosToMulEdge.FromNodeIndex = gatherCosOpIndex; + cosToMulEdge.FromNodeOutputIndex = 0; + cosToMulEdge.ToNodeIndex = mulCosOpIndex; + cosToMulEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(cosToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedDataToMulEdge = {}; + rotatedDataToMulEdge.FromNodeIndex = joinInputOpIndex; + rotatedDataToMulEdge.FromNodeOutputIndex = 0; + rotatedDataToMulEdge.ToNodeIndex = mulSinOpIndex; + rotatedDataToMulEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(rotatedDataToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC sinToMulEdge = {}; + sinToMulEdge.FromNodeIndex = gatherSinOpIndex; + sinToMulEdge.FromNodeOutputIndex = 0; + sinToMulEdge.ToNodeIndex = mulSinOpIndex; + sinToMulEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(sinToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedSinToMulEdge = {}; + rotatedSinToMulEdge.FromNodeIndex = mulSinOpIndex; + rotatedSinToMulEdge.FromNodeOutputIndex = 0; + rotatedSinToMulEdge.ToNodeIndex = mulSignOpIndex; + rotatedSinToMulEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(rotatedSinToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC signToMulEdge = {}; + signToMulEdge.FromNodeIndex = signRangeOpIndex; + signToMulEdge.FromNodeOutputIndex = 0; + signToMulEdge.ToNodeIndex = mulSignOpIndex; + signToMulEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(signToMulEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC nonRotatedCosToAddEdge = {}; + nonRotatedCosToAddEdge.FromNodeIndex = mulCosOpIndex; + nonRotatedCosToAddEdge.FromNodeOutputIndex = 0; + nonRotatedCosToAddEdge.ToNodeIndex = addOpIndex; + nonRotatedCosToAddEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(nonRotatedCosToAddEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC rotatedSinToAddEdge = {}; + rotatedSinToAddEdge.FromNodeIndex = mulSignOpIndex; + rotatedSinToAddEdge.FromNodeOutputIndex = 0; + rotatedSinToAddEdge.ToNodeIndex = addOpIndex; + rotatedSinToAddEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(rotatedSinToAddEdge); + + DML_OUTPUT_GRAPH_EDGE_DESC addToOutputEdge = {}; + addToOutputEdge.FromNodeIndex = addOpIndex; + addToOutputEdge.FromNodeOutputIndex = 0; + addToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(addToOutputEdge); + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(RotaryEmbedding, DmlOperatorRotaryEmbedding); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 30bc6e5e27..28360f09bc 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -510,6 +510,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BitwiseAnd); DML_OP_EXTERN_CREATION_FUNCTION(BitwiseOr); DML_OP_EXTERN_CREATION_FUNCTION(BitwiseXor); DML_OP_EXTERN_CREATION_FUNCTION(BitwiseNot); +DML_OP_EXTERN_CREATION_FUNCTION(RotaryEmbedding); DML_OP_EXTERN_QUERY_FUNCTION(MaxPool); DML_OP_EXTERN_QUERY_FUNCTION(Slice); @@ -527,6 +528,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Attention); constexpr static std::array typeNameListDefault = {"T"}; constexpr static std::array typeNameListDefaultV = {"V"}; constexpr static std::array typeNameListAttention = {"T", "M"}; +constexpr static std::array typeNameListRotaryEmbedding = {"T", "M"}; constexpr static std::array typeNameListTwo = { "T1", "T2" }; constexpr static std::array typeNameListLayerNorm = { "T", "U" }; constexpr static std::array typeNameListLayerNormContrib = { "T", "V" }; @@ -597,6 +599,7 @@ constexpr static std::array supportedTypeListShape constexpr static std::array supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8}; constexpr static std::array supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32}; +constexpr static std::array supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32}; constexpr static std::array supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool}; @@ -1006,6 +1009,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)}, {REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)}, {REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)}, {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)}, {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index dac128f92a..e9591cfce6 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -122,6 +122,7 @@ namespace AttrName static constexpr const char* GraphFusedActivation = "activation"; static constexpr const char* GraphFusedAxis = "activation_axis"; + static constexpr const char* Interleaved = "interleaved"; } // namespace AttrName diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 485e20c1df..f7e545d9d9 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1584,6 +1584,7 @@ using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Attention = AttentionHelper; using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper; +using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_IsNaN = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Erf = GetBroadcastedOutputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index c1e525400b..e18ba31def 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -437,6 +437,7 @@ namespace OperatorHelper static const int sc_sinceVer_BiasAdd = 1; static const int sc_sinceVer_QuickGelu = 1; static const int sc_sinceVer_GroupNorm = 1; + static const int sc_sinceVer_RotaryEmbedding = 1; } // namespace MsftOperatorSet1 } // namespace OperatorHelper diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index a46f820c62..021e7e3adb 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -27,8 +27,12 @@ namespace onnxruntime { struct DMLProviderFactory : IExecutionProviderFactory { DMLProviderFactory(IDMLDevice* dml_device, - ID3D12CommandQueue* cmd_queue) : dml_device_(dml_device), - cmd_queue_(cmd_queue) {} + ID3D12CommandQueue* cmd_queue, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) : dml_device_(dml_device), + cmd_queue_(cmd_queue), + metacommands_enabled_(!disable_metacommands), + dynamic_graph_fusion_enabled_(enable_dynamic_graph_fusion) {} ~DMLProviderFactory() override {} std::unique_ptr CreateProvider() override; @@ -39,10 +43,11 @@ struct DMLProviderFactory : IExecutionProviderFactory { ComPtr dml_device_{}; ComPtr cmd_queue_{}; bool metacommands_enabled_ = true; + bool dynamic_graph_fusion_enabled_ = false; }; std::unique_ptr DMLProviderFactory::CreateProvider() { - auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_); + auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_, dynamic_graph_fusion_enabled_); return provider; } @@ -51,7 +56,9 @@ void DMLProviderFactory::SetMetacommandsEnabled(bool metacommands_enabled) { } std::shared_ptr CreateExecutionProviderFactory_DML(IDMLDevice* dml_device, - ID3D12CommandQueue* cmd_queue) { + ID3D12CommandQueue* cmd_queue, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { #ifndef _GAMING_XBOX // Validate that the D3D12 devices match between DML and the command queue. This specifically asks for IUnknown in // order to be able to compare the pointers for COM object identity. @@ -70,7 +77,7 @@ std::shared_ptr CreateExecutionProviderFactory_DML(ID const Env& env = Env::Default(); auto luid = d3d12_device->GetAdapterLuid(); env.GetTelemetryProvider().LogExecutionProviderEvent(&luid); - return std::make_shared(dml_device, cmd_queue); + return std::make_shared(dml_device, cmd_queue, disable_metacommands, enable_dynamic_graph_fusion); } void DmlConfigureProviderFactoryMetacommandsEnabled(IExecutionProviderFactory* factory, bool metacommandsEnabled) { @@ -92,8 +99,43 @@ bool IsSoftwareAdapter(IDXGIAdapter1* adapter) { return isSoftwareAdapter || (isBasicRenderDriverVendorId && isBasicRenderDriverDeviceId); } -std::shared_ptr DMLProviderFactoryCreator::Create(int device_id) { - return Create(device_id, /*skip_software_device_check*/ false); +static std::optional ParseDeviceId(const ProviderOptions& provider_options) { + static const std::string DeviceId = "device_id"; + + auto preference_it = provider_options.find(DeviceId); + if (preference_it != provider_options.end()) { + if (!preference_it->second.empty()) { + return std::stoi(preference_it->second); + } + } + + return {}; +} + +static bool ParseBoolean(const ProviderOptions& provider_options, const std::string& key) { + auto preference_it = provider_options.find(key); + if (preference_it != provider_options.end() && !preference_it->second.empty()) { + if (preference_it->second == "True" || preference_it->second == "true") { + return true; + } else if (preference_it->second == "False" || preference_it->second == "false") { + return false; + } else { + ORT_THROW("[ERROR] [DirectML] The value for the key '" + key + "' should be 'True' or 'False'. Default value is 'False'.\n"); + } + } + + return false; +} + +std::shared_ptr DMLProviderFactoryCreator::CreateFromProviderOptions( + const ProviderOptions& provider_options) { + + bool disable_metacommands = ParseBoolean(provider_options, "disable_metacommands"); + bool enable_dynamic_graph_fusion = ParseBoolean(provider_options, "enable_dynamic_graph_fusion"); + bool skip_software_device_check = false; + auto device_id = ParseDeviceId(provider_options); + + return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value_or(0), skip_software_device_check, disable_metacommands, enable_dynamic_graph_fusion); } Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Device(int device_id, bool skip_software_device_check) @@ -128,21 +170,13 @@ Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Devic return d3d12_device; } -std::shared_ptr DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) { - ComPtr d3d12_device = CreateD3D12Device(device_id, skip_software_device_check); - - D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; - cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; - cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; - - ComPtr cmd_queue; - ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); - +Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID3D12Device* d3d12_device) +{ DML_CREATE_DEVICE_FLAGS flags = DML_CREATE_DEVICE_FLAG_NONE; // In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled #if _DEBUG && !_GAMING_XBOX - ComPtr debug_device; + Microsoft::WRL::ComPtr debug_device; (void)d3d12_device->QueryInterface(IID_PPV_ARGS(&debug_device)); // ignore failure const bool is_d3d12_debug_layer_enabled = (debug_device != nullptr); @@ -151,13 +185,32 @@ std::shared_ptr DMLProviderFactoryCreator::Create(int } #endif - ComPtr dml_device; - ORT_THROW_IF_FAILED(DMLCreateDevice1(d3d12_device.Get(), - flags, - DML_FEATURE_LEVEL_5_0, - IID_PPV_ARGS(&dml_device))); + Microsoft::WRL::ComPtr dml_device; + ORT_THROW_IF_FAILED(DMLCreateDevice1( + d3d12_device, + flags, + DML_FEATURE_LEVEL_5_0, + IID_PPV_ARGS(&dml_device))); + + return dml_device; +} + +std::shared_ptr DMLProviderFactoryCreator::Create( + int device_id, + bool skip_software_device_check, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { + ComPtr d3d12_device = CreateD3D12Device(device_id, skip_software_device_check); + + D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; + cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; + cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; + + ComPtr cmd_queue; + ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); - return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get()); + auto dml_device = CreateDMLDevice(d3d12_device.Get()); + return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get(), disable_metacommands, enable_dynamic_graph_fusion); } } // namespace onnxruntime @@ -167,7 +220,7 @@ std::shared_ptr DMLProviderFactoryCreator::Create(int // The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead. ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id) { API_IMPL_BEGIN - options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(device_id)); + options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(device_id, false, false, false)); API_IMPL_END return nullptr; } @@ -179,7 +232,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSess _In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue) { API_IMPL_BEGIN options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_DML(dml_device, - cmd_queue)); + cmd_queue, + false, + false)); API_IMPL_END return nullptr; } diff --git a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h index b1c9bb3f6f..e136f13ff9 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h +++ b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h @@ -7,14 +7,23 @@ #include #include +#include "core/framework/provider_options.h" #include "core/providers/providers.h" #include "core/providers/dml/dml_provider_factory.h" namespace onnxruntime { struct DMLProviderFactoryCreator { - static std::shared_ptr Create(int device_id); - static std::shared_ptr Create(int device_id, bool skip_software_device_check); + static std::shared_ptr Create( + int device_id, + bool skip_software_device_check, + bool disable_metacommands, + bool enable_dynamic_graph_fusion); + + static std::shared_ptr CreateFromProviderOptions( + const ProviderOptions& provider_options_map); + static Microsoft::WRL::ComPtr CreateD3D12Device(int device_id, bool skip_software_device_check); + static Microsoft::WRL::ComPtr CreateDMLDevice(ID3D12Device* d3d12_device); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index d9c6126d4b..3fe1980141 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -231,7 +231,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul); @@ -464,7 +465,8 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/gemm.cc b/onnxruntime/core/providers/js/operators/gemm.cc index f579d62bdf..04700d0f54 100644 --- a/onnxruntime/core/providers/js/operators/gemm.cc +++ b/onnxruntime/core/providers/js/operators/gemm.cc @@ -12,7 +12,15 @@ namespace js { ONNX_OPERATOR_TYPED_KERNEL_EX( \ Gemm, \ kOnnxDomain, \ - 11, \ + 13, \ + T, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Gemm); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Gemm, \ + kOnnxDomain, \ + 11, 12, \ T, \ kJsExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 691af48711..cfacc1aa6a 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -25,8 +25,9 @@ class Split : public JsKernel, public SplitBase { if (num_outputs_ < 0) { num_outputs_ = split_sizes.size(); } - } else if (split_sizes_.size() == 0) { - // Compute split_sizes from input shape and num_outputs + } else if (split_sizes_.size() == 0 && info.GetInputCount() < 2) { + // Compute split_sizes from input shape and num_outputs. + // TODO: Shape might not be known at this point, better to handle this in javascript auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast(axis_)).dim_value(); int64_t split_size_sum = 0; if (num_outputs_ < 0) { @@ -44,6 +45,7 @@ class Split : public JsKernel, public SplitBase { ORT_ENFORCE(split_size_sum == total_split_size, "Sum of split sizes (", split_size_sum, ") does not match input size (", total_split_size, ")"); } + // else: let javascript handle all other cases, ie. split_sizes come as input[1] JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1, "numOutputs" : $2, diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc index 5b5ff0f287..7797e0a47c 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc @@ -112,6 +112,9 @@ bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initial }; for (const auto& input : node_unit.Inputs()) { + if (!input.node_arg.Exists()) { + continue; + } if (!has_supported_shape(input.node_arg, node_unit.Name(), node_unit.OpType())) return false; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc index 618779f6d2..8d0347673b 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc @@ -51,10 +51,11 @@ void ReductionOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, cons Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { const auto& op_type(node_unit.OpType()); const auto& inputs = node_unit.Inputs(); + const auto& input = node_unit.Inputs()[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); auto& shaper(model_builder.GetShaper()); - const auto input_shape = shaper[inputs[0].node_arg.Name()]; + const auto input_shape = shaper[input]; const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); @@ -99,10 +100,10 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co } // Add ReduceMean operation - InlinedVector input_indices; - input_indices.push_back(operand_indices.at(inputs[0].node_arg.Name())); // data - if (!axes.empty()) { + InlinedVector input_indices; + input_indices.push_back(operand_indices.at(input)); // data + const auto axes_name = model_builder.GetUniqueName(node_unit.Name() + inputs[0].node_arg.Name() + "_axes"); Shape axes_dimen = {static_cast(axes.size())}; const OperandType axes_operand_type(Type::TENSOR_INT32, axes_dimen); @@ -110,17 +111,17 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co input_indices.push_back(operand_indices.at(axes_name)); // axes - int32_t input_size = static_cast(input_shape.size()); + int32_t input_rank = static_cast(input_shape.size()); // Make output dimensions InlinedVector output_dimen; if (keepdims) { - output_dimen.reserve(input_size); + output_dimen.reserve(input_rank); } else { - output_dimen.reserve(input_size - axes.size()); + output_dimen.reserve(input_rank - axes.size()); } - for (int32_t i = 0; i < input_size; i++) { + for (int32_t i = 0; i < input_rank; i++) { if (std::find(axes.begin(), axes.end(), i) == axes.end()) { output_dimen.push_back(input_shape[i]); } else { @@ -143,10 +144,14 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type})); } else { - // If `axes` is still empty at this point, meaning that it's ReduceMean-18 and attribute `noop_with_empty_axes` specifies as 1, - // treat as an Identity op here. - const OperandType output_operand_type(operand_types.at(inputs[0].node_arg.Name()).type, input_shape); - model_builder.RegisterOperand(output, operand_indices.at(inputs[0].node_arg.Name()), output_operand_type); + // Note: If `axes` is still empty at this point, meaning it's ReduceMean-18 and attribute `noop_with_empty_axes` + // specifies as 1. We treat this case as an Identity op in NNAPI EP. + // However, we hit an issue while adding no-ops operation in NNAPI because it doesn't allow adding an operand both as + // an input and output. + // Currently, we return not supported in NNAPI EP when `noop_with_empty_axes` is true. + + // const OperandType output_operand_type(operand_types.at(input).type, input_shape); + // model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type); } return Status::OK(); @@ -169,6 +174,8 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ const auto& inputs = node_unit.Inputs(); const auto& op(node_unit.OpType()); + NodeAttrHelper helper(node_unit); + Shape input_shape; if (!GetShape(inputs[0].node_arg, input_shape)) return false; @@ -180,6 +187,7 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ } if (op == "ReduceMean") { + const bool noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0) != 0; if (inputs.size() > 1 && inputs[1].node_arg.Exists()) { const auto& axes_name = inputs[1].node_arg.Name(); if (!Contains(initializers, axes_name)) { @@ -187,6 +195,15 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ return false; } } + // Note: For the case - ReduceMean 18+ with noop_with_empty_axes attribute set as 1, + // currently we hit an issue in NNAPI where it does not allow adding an operand as both an input and output. + // This issue may arise from handling no-ops like Identity and ReduceX with noop_with_empty_axes set. + // TODO: Support the case when a more complete solution is available. + if (node_unit.SinceVersion() >= 18 && noop_with_empty_axes) { + LOGS_DEFAULT(VERBOSE) + << "ReduceMean 18+ with noop_with_empty_axes attribute set as 1 is not supported for now."; + return false; + } } return true; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc index 01e348caf1..cdaa1c8fac 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc @@ -153,10 +153,10 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) return false; - const auto input_size = input_shape.size(); - if (input_size != 4) { + const auto input_rank = input_shape.size(); + if (input_rank != 4) { LOGS_DEFAULT(VERBOSE) << "Resize only support 4d shape, input is " - << input_size << "d shape"; + << input_rank << "d shape"; return false; } @@ -206,6 +206,26 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; } } + + // The new feature - antialiasing introduced since opset 18 doesn't have a NNAPI mapping support yet. + // And a few other new attributes are currently not handled by NNAPI EP, can add support in the future if needed. + if (node_unit.SinceVersion() >= 18) { + const auto antialias = helper.Get("antialias", 0); + const auto axes = helper.Get("axes", std::vector{}); + const auto keep_aspect_ratio_policy = helper.Get("keep_aspect_ratio_policy", "stretch"); + if (antialias != 0) { + LOGS_DEFAULT(VERBOSE) << "Resize 18+ antialias feature is not currently supported by NNAPI."; + return false; + } + if (!axes.empty()) { + LOGS_DEFAULT(VERBOSE) << "Resize 18+ axes attribute is not currently supported by NNAPI EP."; + return false; + } + if (keep_aspect_ratio_policy != "stretch") { + LOGS_DEFAULT(VERBOSE) << "Resize 18+ keep_aspect_ratio_policy attribute is not currently supported by NNAPI EP."; + return false; + } + } } { // scales and sizes (if present) must be initializers @@ -216,20 +236,22 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers } // scales - if (inputs.size() == 3 && !Contains(initializers, inputs[2].node_arg.Name())) { + bool using_scales = (inputs.size() > 2 && inputs[2].node_arg.Exists()); + if (using_scales && !Contains(initializers, inputs[2].node_arg.Name())) { LOGS_DEFAULT(VERBOSE) << "Input scales of Resize must be known"; return false; } // sizes - if (inputs.size() > 3 && !Contains(initializers, inputs[3].node_arg.Name())) { + bool using_sizes = inputs.size() > 3 && inputs[3].node_arg.Exists(); + if (using_sizes && !Contains(initializers, inputs[3].node_arg.Name())) { LOGS_DEFAULT(VERBOSE) << "Input sizes of Resize must be known"; return false; } bool input_is_nchw = false; // haven't a good solution to check layout when scale is 1.0F // We want to check if the scales or sizes are not trying to resize on N/C channels here - if (inputs.size() == 3) { // we are using scales + if (using_scales) { // we are using scales const auto& scales_tensor = *initializers.at(inputs[2].node_arg.Name()); Initializer const unpacked_tensor(scales_tensor); auto scales_data = unpacked_tensor.DataAsSpan(); diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index 5c516aac65..429ceb1f7c 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -250,6 +250,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } template __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; } +template +__device__ __inline__ T _Signum(T a, std::false_type /* is_signed */) { return T(0) < a; } + +template +__device__ __inline__ T _Signum(T a, std::true_type /* is_signed */) { return (T(0) < a) - (a < T(0)); } + +template +__device__ __inline__ T _Sign(T a) { return _Signum(a, std::is_signed()); } + +template <> +__device__ __inline__ half _Sign(half a) { return _Signum(a, std::true_type()); } + template __device__ __inline__ T _Normcdf(T a); @@ -337,7 +349,7 @@ struct GridDim { }; // aligned vector generates vectorized load/store -template +template struct alignas(sizeof(T) * vec_size) aligned_vector { T val[vec_size]; }; @@ -350,11 +362,11 @@ struct alignas(sizeof(T) * vec_size) aligned_vector { // HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels. // TODO ROCM added support recently, should verify. #define HIP_KERNEL_ASSERT(...) -//#define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__) +// #define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__) // WARP related definitions and functions constexpr int GPU_WARP_SIZE = warpSize; -inline int GPU_WARP_SIZE_HOST= warpSizeDynamic(); +inline int GPU_WARP_SIZE_HOST = warpSizeDynamic(); template __device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = GPU_WARP_SIZE, unsigned int mask = 0xffffffff) { diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 9401de6426..e6ea876d89 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1105,6 +1105,17 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign); // OpSet 14 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum); @@ -2067,6 +2078,17 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // OpSet 14 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/shared/node_unit/node_unit.cc b/onnxruntime/core/providers/shared/node_unit/node_unit.cc index 10dd58ba28..8e73563f9d 100644 --- a/onnxruntime/core/providers/shared/node_unit/node_unit.cc +++ b/onnxruntime/core/providers/shared/node_unit/node_unit.cc @@ -24,6 +24,7 @@ enum class QLinearOpType : uint8_t { QLinearConcat, QLinearGlobalAveragePool, QLinearLeakyRelu, + QLinearConvTranspose, }; QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) { @@ -52,6 +53,8 @@ QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) { return QLinearOpType::QLinearGlobalAveragePool; else if (op_type == "QLinearLeakyRelu") return QLinearOpType::QLinearLeakyRelu; + else if (op_type == "QLinearConvTranspose") + return QLinearOpType::QLinearConvTranspose; return QLinearOpType::Unknown; } @@ -70,7 +73,8 @@ bool IsBinaryQLinearOp(QLinearOpType type) { return type == QLinearOpType::QLinearConv || type == QLinearOpType::QLinearMatMul || type == QLinearOpType::QLinearAdd || - type == QLinearOpType::QLinearMul; + type == QLinearOpType::QLinearMul || + type == QLinearOpType::QLinearConvTranspose; } // Ops have 1 or more inputs diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index f0ab7869b7..c0b282b202 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#pragma once namespace onnxruntime { extern ProviderHost* g_host; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 3f0cfdac8a..ac92d46ca8 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -792,6 +792,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv if (info.has_user_compute_stream) { external_stream_ = true; stream_ = static_cast(info.user_compute_stream); + ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_))); + ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream_))); + ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_))); + ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream_))); } std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; @@ -1034,10 +1038,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); } - if (cuda_graph_enable_) { - GetPerThreadContext().InitCUDAGraph(); - } - /* * Parse explicit min/max/opt profile shapes from provider options. * @@ -1142,50 +1142,35 @@ bool TensorrtExecutionProvider::IsGraphCaptureEnabled() const { return cuda_graph_enable_; } -bool TensorrtExecutionProvider::IsGraphCaptured() const { - return GetPerThreadContext().IsGraphCaptured(); -} - -Status TensorrtExecutionProvider::ReplayGraph() { - return GetPerThreadContext().ReplayGraph(); -} - -void TensorrtExecutionProvider::PerThreadContext::InitCUDAGraph() { - cuda_graph_ = std::make_unique(); -} - -void TensorrtExecutionProvider::PerThreadContext::SetGraphStream(cudaStream_t stream) { - cuda_graph_->SetStream(stream); -} - -bool TensorrtExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { +bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const { return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; } -void TensorrtExecutionProvider::PerThreadContext::CaptureBegin() { - cuda_graph_->Reset(); - cuda_graph_->CaptureBegin(); +void TensorrtExecutionProvider::CaptureBegin() { + cuda_graph_.Reset(); + cuda_graph_.CaptureBegin(); } -void TensorrtExecutionProvider::PerThreadContext::CaptureEnd() { - cuda_graph_->CaptureEnd(); +void TensorrtExecutionProvider::CaptureEnd() { + cuda_graph_.CaptureEnd(); is_graph_captured_ = true; } -bool TensorrtExecutionProvider::PerThreadContext::IsGraphCaptured() const { +bool TensorrtExecutionProvider::IsGraphCaptured() const { return is_graph_captured_; } -Status TensorrtExecutionProvider::PerThreadContext::ReplayGraph() { +Status TensorrtExecutionProvider::ReplayGraph() { ORT_ENFORCE(IsGraphCaptured()); // Please note that CUDAGraph::Replay() is not thread safe. - // The cuda graph object is maintained by a per thread basis, + // ORT TRT calls ReplayGraph() in compute_func() where synchromization is enforced due to lock_guard(), // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe. - return cuda_graph_->Replay(); + return cuda_graph_.Replay(); } -void TensorrtExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { - // The cuda graph object is maintained by a per thread basis, +void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + // Please note that this function is not thread safe. + // ORT TRT calls this function in compute_func() where synchronization is enforced due to lock_guard(), // therefore following increment is guaranteed to be thread safe. ++regular_run_count_before_graph_capture_; } @@ -1216,18 +1201,6 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { if (sync_stream && external_stream_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); } - - // The reason of !IsGraphCaptureEnabled(): - // If cuda graph is enabled, the per thread context will not be released - // because the per thread cuda graph needs to be maintained and replayed for - // the next run. - // The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end(): - // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), - // PerThreadContext won't be created and there is nothing to release. - if (!IsGraphCaptureEnabled() && - PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) { - ReleasePerThreadContext(); - } return Status::OK(); } @@ -1891,6 +1864,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } else if (number_of_trt_nodes == number_of_ort_nodes) { LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; } else { + sync_stream_after_enqueue_ = true; LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; } @@ -2381,6 +2355,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name, - &parsers_[context->node_name], &engines_[context->node_name], &builders_[context->node_name], + &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, + input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, @@ -2430,13 +2397,20 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(state); + + // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine, + // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading. + // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); const std::unordered_map& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; + bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; auto fused_node_name = trt_state->fused_node_name; auto& shape_ranges = trt_state->input_shape_ranges; auto trt_builder = trt_state->builder->get(); auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; int num_inputs = static_cast(input_indexes.size()); @@ -2471,260 +2445,238 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector lock(*(trt_state->tensorrt_mu_ptr)); - - // Load serialized engine - if (trt_state->engine_cache_enable && trt_engine == nullptr) { - std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); - std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); - if (engine_file && profile_file) { - // Deserialize profile - shape_ranges = DeserializeProfileV2(profile_file); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - - // Prepare buffer - engine_file.seekg(0, std::ios::end); - size_t engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - - // Deserialize engine - // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_state->engine->reset(); - *(trt_state->engine) = std::unique_ptr( - trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (*(trt_state->engine) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); - } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; - trt_engine = trt_state->engine->get(); - context_update = true; - } else if (trt_state->engine_decryption_enable && !engine_file && profile_file) { - shape_ranges = DeserializeProfileV2(profile_file); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - // Decrypt engine - size_t engine_size = 0; - if (!trt_state->engine_decryption(engine_cache_path.c_str(), nullptr, &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not get engine buffer size"); - } - std::unique_ptr engine_buf{new char[engine_size]}; - if (!trt_state->engine_decryption(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine decryption function decrypt"); - } - // Deserialize engine - // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_state->engine->reset(); - *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (*(trt_state->engine) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path); - } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; - trt_engine = trt_state->engine->get(); - context_update = true; - } - } + // Load serialized engine + if (trt_state->engine_cache_enable && trt_engine == nullptr) { + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); + if (engine_file && profile_file) { + // Deserialize profile + shape_ranges = DeserializeProfileV2(profile_file); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - // Check and update shape ranges for dynamic shape inputs. - for (int i = 0, end = num_inputs; i < end; ++i) { - auto input = trt_state->network->get()->getInput(i); - const std::string& input_name = input->getName(); - input_names.insert(input_name); + // Prepare buffer + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); - // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved. - // TRT EP will help determine the min/max/opt profile values based on current input tensor value. - if (shape_ranges.find(input_name) != shape_ranges.end()) { - auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); - } + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } - } - - // Regenerate engine - if (engine_update) { - // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior. - if (GetPerThreadContext().IsTensorRTContextInMap(fused_node_name)) { - GetPerThreadContext().ResetTensorRTContext(fused_node_name); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; + } else if (trt_state->engine_decryption_enable && !engine_file && profile_file) { + shape_ranges = DeserializeProfileV2(profile_file); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + // Decrypt engine + size_t engine_size = 0; + if (!trt_state->engine_decryption(engine_cache_path.c_str(), nullptr, &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not get engine buffer size"); } - + std::unique_ptr engine_buf{new char[engine_size]}; + if (!trt_state->engine_decryption(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine decryption function decrypt"); + } + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading trt_state->engine->reset(); - auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); - trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); - for (auto trt_profile : trt_profiles) { - trt_config->addOptimizationProfile(trt_profile); + *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path); } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; + } + } - // Set INT8 Per Tensor Dynamic range - if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { - trt_config->setInt8Calibrator(nullptr); - if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); - } - } + // Check and update shape ranges for dynamic shape inputs. + for (int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_state->network->get()->getInput(i); + const std::string& input_name = input->getName(); + input_names.insert(input_name); - // Set precision - if (trt_state->fp16_enable && trt_state->int8_enable) { - trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - } else if (trt_state->fp16_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - } else if (trt_state->int8_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved. + // TRT EP will help determine the min/max/opt profile values based on current input tensor value. + if (shape_ranges.find(input_name) != shape_ranges.end()) { + auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); } + } + } - // Set DLA (DLA can only run with FP16 or INT8) - if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; - trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); - trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); - trt_config->setDLACore(trt_state->dla_core); - } + // Regenerate engine + if (engine_update) { + // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior. + trt_state->context->reset(); + trt_state->engine->reset(); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); + for (auto trt_profile : trt_profiles) { + trt_config->addOptimizationProfile(trt_profile); + } - // enable sparse weights - if (trt_state->sparsity_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + // Set INT8 Per Tensor Dynamic range + if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { + trt_config->setInt8Calibrator(nullptr); + if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); } + } - // enable builder heuristics - if (trt_state->build_heuristics_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; - } + // Set precision + if (trt_state->fp16_enable && trt_state->int8_enable) { + trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); + } else if (trt_state->fp16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + } else if (trt_state->int8_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + } + + // Set DLA (DLA can only run with FP16 or INT8) + if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(trt_state->dla_core); + } + + // enable sparse weights + if (trt_state->sparsity_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + } + + // enable builder heuristics + if (trt_state->build_heuristics_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; + } #if NV_TENSORRT_MINOR > 5 && NV_TENSORRT_MAJOR >= 8 - // switch optimizaion level - if (trt_state->builder_optimization_level != 3) { - trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; - } + // switch optimizaion level + if (trt_state->builder_optimization_level != 3) { + trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + } - // limit auxiliary streams - if (trt_state->auxiliary_streams >= 0) { - trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; - } + // limit auxiliary streams + if (trt_state->auxiliary_streams >= 0) { + trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; + } #else - if (trt_state->builder_optimization_level != 3) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; - } - if (trt_state->auxiliary_streams >= 0) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; - } + if (trt_state->builder_optimization_level != 3) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + } + if (trt_state->auxiliary_streams >= 0) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + } #endif - // limit used tactic sources - if (trt_state->filter_tactic_sources) { - nvinfer1::TacticSources tactics = trt_config->getTacticSources(); - tactics |= trt_state->tactic_sources; - trt_config->setTacticSources(tactics); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; + // limit used tactic sources + if (trt_state->filter_tactic_sources) { + nvinfer1::TacticSources tactics = trt_config->getTacticSources(); + tactics |= trt_state->tactic_sources; + trt_config->setTacticSources(tactics); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; + } + + // Load timing cache from file. Create a fresh cache if the file doesn't exist + std::unique_ptr timing_cache = nullptr; + if (trt_state->timing_cache_enable) { + std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); + timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); + if (timing_cache == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not create timing cache: " + timing_cache_path); } - - // Load timing cache from file. Create a fresh cache if the file doesn't exist - std::unique_ptr timing_cache = nullptr; - if (trt_state->timing_cache_enable) { - std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); - timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); - if (timing_cache == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not create timing cache: " + timing_cache_path); - } - trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); - if (detailed_build_log_) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; - } + trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); + if (detailed_build_log_) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; } + } - // Build engine - { - auto lock = GetApiLock(); - std::chrono::steady_clock::time_point engine_build_start; - if (detailed_build_log_) { - engine_build_start = std::chrono::steady_clock::now(); - } - *(trt_state->engine) = std::unique_ptr( - trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config)); - if (detailed_build_log_) { - auto engine_build_stop = std::chrono::steady_clock::now(); - LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; - } + // Build engine + { + auto lock = GetApiLock(); + std::chrono::steady_clock::time_point engine_build_start; + if (detailed_build_log_) { + engine_build_start = std::chrono::steady_clock::now(); } - if (*(trt_state->engine) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + *(trt_state->engine) = std::unique_ptr( + trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config)); + if (detailed_build_log_) { + auto engine_build_stop = std::chrono::steady_clock::now(); + LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; } - trt_engine = trt_state->engine->get(); - if (trt_state->engine_cache_enable) { - // Serialize engine profile - SerializeProfileV2(profile_cache_path, shape_ranges); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + } + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + trt_engine = trt_state->engine->get(); + if (trt_state->engine_cache_enable) { + // Serialize engine profile + SerializeProfileV2(profile_cache_path, shape_ranges); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; - // Serialize engine - std::unique_ptr serializedModel(trt_engine->serialize()); - size_t engine_size = serializedModel->size(); - if (trt_state->engine_decryption_enable) { - // Encrypt engine - if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine encryption function encrypt"); - } - } else { - std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); - file.write(reinterpret_cast(serializedModel->data()), engine_size); + // Serialize engine + std::unique_ptr serializedModel(trt_engine->serialize()); + size_t engine_size = serializedModel->size(); + if (trt_state->engine_decryption_enable) { + // Encrypt engine + if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine encryption function encrypt"); } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; + } else { + std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serializedModel->data()), engine_size); } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; + } - // serialize and save timing cache - if (trt_state->timing_cache_enable) { - auto timing_cache = trt_config->getTimingCache(); - std::unique_ptr timingCacheHostData{timing_cache->serialize()}; - if (timingCacheHostData == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not serialize timing cache: " + timing_cache_path); - } - saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); - if (detailed_build_log_) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; - } + // serialize and save timing cache + if (trt_state->timing_cache_enable) { + auto timing_cache = trt_config->getTimingCache(); + std::unique_ptr timingCacheHostData{timing_cache->serialize()}; + if (timingCacheHostData == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not serialize timing cache: " + timing_cache_path); + } + saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); + if (detailed_build_log_) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; } - context_update = true; } + context_update = true; } - // Build execution context if either of the following conditions is true: - // (1) The engine is built or updated by this thread. - // (2) The first inference run for this thread where there is no IExecutionContext object yet. - // (3) The engine is updated by another thread. (We compare the profile shapes maintained by the PerThreadContext to the profile shapes maintained by TRT EP) - // - // Note: Creating an execution context from an engine is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - if (context_update || - !GetPerThreadContext().IsTensorRTContextInMap(fused_node_name) || - GetPerThreadContext().CompareProfileShapes(fused_node_name, shape_ranges)) { - std::unique_ptr new_context; + if (context_update) { if (trt_state->context_memory_sharing_enable) { - new_context.reset(trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); } else { - new_context.reset(trt_state->engine->get()->createExecutionContext()); + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContext()); } - auto context_status = GetPerThreadContext().UpdateTensorRTContext(fused_node_name, std::move(new_context)); - if (!context_status) { + if (!(*(trt_state->context))) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); } - GetPerThreadContext().UpdateProfileShapes(fused_node_name, shape_ranges); + trt_context = trt_state->context->get(); } - // Get the reference to the IExecutionContext object that is maintained on a per thread basis. - nvinfer1::IExecutionContext& trt_context = GetPerThreadContext().GetTensorRTContext(fused_node_name); - // Get input and output binding names int total_bindings = trt_engine->getNbBindings(); std::vector buffers(total_bindings); @@ -2760,12 +2712,12 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorisShapeBinding(binding_index)) { - trt_context.setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]); + trt_context->setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]); } else { for (int j = 0, end = nb_dims; j < end; ++j) { dimensions.d[j] = static_cast(tensor_shapes[j]); } - const bool status = trt_context.setBindingDimensions(binding_index, dimensions); + const bool status = trt_context->setBindingDimensions(binding_index, dimensions); if (!status) { ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP cannot set the dynamic dimensions of a binding")); @@ -2904,7 +2856,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsecond; } - nvinfer1::Dims dimensions = trt_context.getBindingDimensions(static_cast(binding_index)); + nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast(binding_index)); int nb_dims = dimensions.nbDims; std::vector output_shapes(nb_dims); for (int j = 0, end = nb_dims; j < end; ++j) { @@ -3038,23 +2990,27 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector *max_context_mem_size_ptr) { *max_context_mem_size_ptr = mem_size; } - trt_context.setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); } // Start CUDA graph capture. // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - GetPerThreadContext().SetGraphStream(stream); - GetPerThreadContext().CaptureBegin(); + cuda_graph_.SetStream(stream); + CaptureBegin(); } // Run TRT inference - if (!trt_context.enqueueV2(&buffers[0], stream, nullptr)) { + if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); } + if (sync_stream_after_enqueue) { + cudaStreamSynchronize(stream); + } + // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { const std::string& output_name = output_binding_names[i]; @@ -3082,14 +3038,14 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector* parser = nullptr; std::unique_ptr* engine = nullptr; + std::unique_ptr* context = nullptr; std::unique_ptr* builder = nullptr; std::unique_ptr* network = nullptr; std::vector> input_info; std::vector> output_info; std::unordered_map>>> input_shape_ranges; + bool sync_stream_after_enqueue = false; OrtMutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; bool int8_enable = false; @@ -246,6 +248,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. std::unordered_map> parsers_; std::unordered_map> engines_; + std::unordered_map> contexts_; std::unordered_map> builders_; std::unordered_map> networks_; std::unordered_map>> input_info_; @@ -256,6 +259,24 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_map input_shape_ranges_; // The profile shape ranges that the engine is built with std::unordered_map> profiles_; + // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture + cudnnHandle_t external_cudnn_handle_ = nullptr; + cublasHandle_t external_cublas_handle_ = nullptr; + + // Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3() + mutable bool sync_stream_after_enqueue_ = false; + + CUDAGraph cuda_graph_; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: + // (1) memory pattern is enabled. (2) arena allocation for stream. + // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs + // to allocate enough memory in Arena before graph capturing. + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + + // [Note] We don't use PerThreadContext for now since it has issue with multithreading + // // TRT or CUDA objects that must be maintained on a per thread basis will be put under this PerThreadContext data structure. // For example, TensorRT execution context and CUDA graph are the ones to be put here. class PerThreadContext final { @@ -306,15 +327,15 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_map input_shape_ranges_; // Cuda graph with multi threads will be supported in the future, so cuda_graph_ is put under PerThreadContext. - // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph pointer is enough (no need to maintain one CUDAGraph pointer per TRT subgraph) - std::unique_ptr cuda_graph_; + // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph instance is enough (no need to maintain one CUDAGraph instance per TRT subgraph) + CUDAGraph cuda_graph_; bool is_graph_captured_ = false; - int regular_run_count_before_graph_capture_ = -1; + int regular_run_count_before_graph_capture_ = 0; // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: // (1) memory pattern is enabled. (2) arena allocation for stream. // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs // to allocate enough memory in Arena before graph capturing. - const int min_num_runs_before_cuda_graph_capture_ = 0; // required min regular runs before graph capture for the necessary memory allocations. + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. }; using PerThreadContextMap = std::unordered_map>; diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index 4df11c2224..b5f45b15a5 100644 --- a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -138,6 +138,7 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri model_proto = model.ToProto(); } else { model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat, + ToPathString(filename), initializer_size_threshold); } auto& metadata = model.MetaData(); diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 80a0cb673c..fc3fd4eb7d 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -234,6 +234,7 @@ Status Environment::Initialize(std::unique_ptr logging_ domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSInternalNHWCDomain, 1, onnx_version); domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kPytorchAtenDomain, 1, 1); + domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kQuadricDomain, 1, 1); #ifdef USE_DML domainToVersionRangeInstance.AddDomainToVersion(onnxruntime::kMSDmlDomain, 1, 1); #endif diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 6a70176ebc..067cef4f07 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -52,8 +52,10 @@ #include "core/providers/cpu/cpu_execution_provider.h" #ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph #include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h" #include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h" #include "core/providers/dml/dml_session_options_config_keys.h" +#include "core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h" #endif #include "core/session/environment.h" #include "core/session/IOBinding.h" @@ -1531,7 +1533,9 @@ common::Status InferenceSession::Initialize() { record_runtime_optimization_produced_op_schema)); #ifdef USE_DML - if (execution_providers_.Get(kDmlExecutionProvider)) { + const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); + + if (dmlExecutionProvider) { // DML graph fusion is an important runtime optimization that cannot be done ahead of time; it must be disabled // when running in "offline mode" and saving an optimized model to disk. To support users that want to optimize // models offline, and then disable graph optimizations when running "online", this transformer ignores the ORT @@ -1541,11 +1545,20 @@ common::Status InferenceSession::Initialize() { if (dml_graph_fusion_enabled) { std::unique_ptr dmlGraphFusionTransformer = std::make_unique("DmlGraphFusionTransformer", - execution_providers_.Get(kDmlExecutionProvider)); + dmlExecutionProvider); if (dmlGraphFusionTransformer == nullptr) { return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr"); } ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); + + if (static_cast(dmlExecutionProvider)->DynamicGraphFusionEnabled()) { + std::unique_ptr dmlRuntimeGraphFusionTransformer = std::make_unique("DmlRuntimeGraphFusionTransformer", + dmlExecutionProvider); + if (dmlRuntimeGraphFusionTransformer == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "DmlRuntimeGraphFusionTransformer is nullptr"); + } + ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlRuntimeGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); + } } // This transformer applies DML-specific fusions that go beyond what ORT offers by default diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 15fe5acfe0..70d2d0fe5d 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2711,9 +2711,8 @@ static constexpr OrtApi ort_api_1_to_16 = { &OrtApis::GetTensorRTProviderOptionsByName, &OrtApis::UpdateCUDAProviderOptionsWithValue, &OrtApis::GetCUDAProviderOptionsByName, - // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) - &OrtApis::KernelContext_GetResource, + // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -2742,10 +2741,10 @@ static_assert(offsetof(OrtApi, ReleaseKernelInfo) / sizeof(void*) == 218, "Size static_assert(offsetof(OrtApi, ReleaseCANNProviderOptions) / sizeof(void*) == 224, "Size of version 13 API cannot change"); static_assert(offsetof(OrtApi, GetSessionConfigEntry) / sizeof(void*) == 238, "Size of version 14 API cannot change"); static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size of version 15 API cannot change"); -static_assert(offsetof(OrtApi, GetCUDAProviderOptionsByName) / sizeof(void*) == 264, "Size of version 16 API cannot change"); +static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: -static_assert(std::string_view(ORT_VERSION) == "1.16.0", +static_assert(std::string_view(ORT_VERSION) == "1.16.2", "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); // 1. Update the hardcoded version string in above static_assert to silence it // 2. If there were any APIs added to ort_api_1_to_16 above: diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 983321593a..135b4bb4c7 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -701,12 +701,12 @@ template struct Im2col; template struct Im2col; template struct Im2col; -template <> -void Col2im(const float* data_col, int64_t channels, int64_t height, +template +void Col2imNCHW(const T* data_col, int64_t channels, int64_t height, int64_t width, int64_t kernel_h, int64_t kernel_w, int64_t dilation_h, int64_t dilation_w, int64_t pad_t, int64_t pad_l, int64_t pad_b, int64_t pad_r, int64_t stride_h, - int64_t stride_w, float* data_im, CPUMathUtil* context) { + int64_t stride_w, T* data_im, Provider* context) { const int64_t output_h = (height + pad_b + pad_t - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; const int64_t output_w = @@ -714,7 +714,7 @@ void Col2im(const float* data_col, int64 const int64_t output_hw = output_h * output_w; const int64_t hw = height * width; const int64_t hwc = hw * channels; - Set(narrow(hwc), 0, data_im, context); + Set(narrow(hwc), 0, data_im, context); // Fast path for zero padding and no dilation // From Torch, modified THNN_(unfolded_acc) @@ -792,6 +792,32 @@ void Col2im(const float* data_col, int64 } } +template <> +void Col2im(const float* data_col, int64_t channels, int64_t height, + int64_t width, int64_t kernel_h, int64_t kernel_w, + int64_t dilation_h, int64_t dilation_w, int64_t pad_t, + int64_t pad_l, int64_t pad_b, int64_t pad_r, int64_t stride_h, + int64_t stride_w, float* data_im, CPUMathUtil* context) { + Col2imNCHW(data_col, channels, height, + width, kernel_h, kernel_w, + dilation_h, dilation_w, pad_t, + pad_l, pad_b, pad_r, stride_h, + stride_w, data_im, context); +} + +template <> +void Col2im(const int32_t* data_col, int64_t channels, int64_t height, + int64_t width, int64_t kernel_h, int64_t kernel_w, + int64_t dilation_h, int64_t dilation_w, int64_t pad_t, + int64_t pad_l, int64_t pad_b, int64_t pad_r, int64_t stride_h, + int64_t stride_w, int32_t* data_im, CPUMathUtil* context) { + Col2imNCHW(data_col, channels, height, + width, kernel_h, kernel_w, + dilation_h, dilation_w, pad_t, + pad_l, pad_b, pad_r, stride_h, + stride_w, data_im, context); +} + template <> void Col2im(const float* data_col, int64_t channels, int64_t height, int64_t width, int64_t kernel_h, int64_t kernel_w, diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index b73fcbbff5..c86469918c 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -446,14 +446,6 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi providers, provider_options = check_and_normalize_provider_args( providers, provider_options, available_providers ) - if not providers and len(available_providers) > 1: - self.disable_fallback() - raise ValueError( - f"This ORT build has {available_providers} enabled. " - "Since ORT 1.9, you are required to explicitly set " - "the providers parameter when instantiating InferenceSession. For example, " - f"onnxruntime.InferenceSession(..., providers={available_providers}, ...)" - ) session_options = self._sess_options if self._sess_options else C.get_default_session_options() if self._model_path: diff --git a/onnxruntime/python/onnxruntime_pybind_iobinding.cc b/onnxruntime/python/onnxruntime_pybind_iobinding.cc index 7638a12bb8..59d5a77bfb 100644 --- a/onnxruntime/python/onnxruntime_pybind_iobinding.cc +++ b/onnxruntime/python/onnxruntime_pybind_iobinding.cc @@ -60,8 +60,6 @@ void addIoBindingMethods(pybind11::module& m) { }) // This binds input as a Tensor that wraps memory pointer along with the OrtMemoryInfo .def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, py::object& element_type, const std::vector& shape, int64_t data_ptr) -> void { - ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is not valid"); - PyArray_Descr* dtype; if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) { throw std::runtime_error("Not a valid numpy type"); diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 10c8a2de7c..f470e9f6b6 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -26,7 +26,18 @@ #include "core/framework/provider_options_utils.h" #ifdef USE_DML -#include "core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h" +using Microsoft::WRL::ComPtr; + +#include +#include "core/providers/dml/DmlExecutionProvider/src/External/D3DX12/d3dx12.h" +#include "core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h" +#include "core/providers/dml/DmlExecutionProvider/src/DescriptorPool.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.h" +#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +#include "core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h" +#include "core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.h" +#include "core/providers/dml/DmlExecutionProvider/src/ReadbackHeap.h" +#include "core/providers/dml/DmlExecutionProvider/src/AllocationInfo.h" #endif namespace onnxruntime { @@ -186,6 +197,11 @@ std::unique_ptr GetGPUDataTransfer() { #endif #ifdef USE_DML + +constexpr GUID execution_context_guid = {0x50fd773b, 0x4462, 0x4b28, {0x98, 0x9e, 0x8c, 0xa0, 0x54, 0x05, 0xbd, 0x4a}}; +constexpr GUID upload_heap_guid = {0x125235f9, 0xef41, 0x4043, {0xa4, 0x9d, 0xdd, 0xc9, 0x61, 0xe7, 0xdb, 0xee}}; +constexpr GUID dml_readback_heap_guid = {0x00d32df8, 0xea2d, 0x40bf, {0xa4, 0x47, 0x9c, 0xb4, 0xbc, 0xf1, 0x1d, 0x5e}}; + AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) { // Current approach is not thread-safe, but there are some bigger infra pieces to put together in order to make // multi-threaded DML allocation work, including maintaining a per-thread DML allocator. @@ -196,13 +212,100 @@ AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) { auto hit = id_to_allocator_map->find(id); if (hit == id_to_allocator_map->end()) { - auto dml_allocator = std::make_shared(id); + constexpr uint32_t device_id = 0; + auto d3d12_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false); + auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device.Get()); + + D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; + cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; + cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; + + ComPtr cmd_queue; + ORT_THROW_IF_FAILED( + d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); + + auto context = std::make_shared(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get()); + + // We leak the upload and readback heaps to keep them alive, just like the map + auto upload_heap = std::make_unique(d3d12_device.Get(), context).release(); + auto readback_heap = std::make_unique(d3d12_device.Get(), context).release(); + + auto dml_allocator = std::make_shared( + d3d12_device.Get(), + context, + CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), + D3D12_HEAP_FLAG_NONE, + D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + std::make_unique(d3d12_device.Get())); + dml_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled); + context->SetAllocator(dml_allocator); + + auto context_ptr = context.get(); + + ORT_THROW_IF_FAILED(d3d12_device->SetPrivateData(execution_context_guid, sizeof(context_ptr), &context_ptr)); + ORT_THROW_IF_FAILED(d3d12_device->SetPrivateData(upload_heap_guid, sizeof(upload_heap), &upload_heap)); + ORT_THROW_IF_FAILED(d3d12_device->SetPrivateData(dml_readback_heap_guid, sizeof(readback_heap), &readback_heap)); + hit = id_to_allocator_map->emplace(id, std::move(dml_allocator)).first; } return hit->second; } +void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes) { + const auto* allocInfo = static_cast(dst); + ID3D12Resource* dst_data = allocInfo->GetResource(); + + ComPtr d3d12_device; + ORT_THROW_IF_FAILED(dst_data->GetDevice(IID_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); + + Dml::ExecutionContext* context = nullptr; + uint32_t context_size = gsl::narrow_cast(sizeof(context)); + ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(execution_context_guid, &context_size, &context)); + + Dml::PooledUploadHeap* upload_heap = nullptr; + uint32_t upload_heap_size = gsl::narrow_cast(sizeof(upload_heap)); + ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(upload_heap_guid, &upload_heap_size, &upload_heap)); + + upload_heap->BeginUploadToGpu( + dst_data, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, gsl::make_span(static_cast(src), num_bytes)); + context->Flush(); + + // We don't use the same command queue as the execution provider, so we need to sync to make sure that all data has + // been uploaded to the resource. This function is usually called before inference just to upload initial data to the + // GPU, so it shouldn't be a bottleneck. + context->GetCurrentCompletionEvent().WaitForSignal(); +} + +void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { + const auto* allocInfo = static_cast(src); + ID3D12Resource* src_data = allocInfo->GetResource(); + + ComPtr d3d12_device; + ORT_THROW_IF_FAILED(src_data->GetDevice(IID_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); + + Dml::ExecutionContext* context = nullptr; + uint32_t context_size = gsl::narrow_cast(sizeof(context)); + ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(execution_context_guid, &context_size, &context)); + + Dml::ReadbackHeap* readback_heap = nullptr; + uint32_t readback_heap_size = gsl::narrow_cast(sizeof(readback_heap)); + ORT_THROW_IF_FAILED(d3d12_device->GetPrivateData(dml_readback_heap_guid, &readback_heap_size, &readback_heap)); + + // ReadbackFromGpu already syncs with the CPU and waits for the copy to be completed, so we don't need to sync after + // this call + readback_heap->ReadbackFromGpu( + gsl::make_span(static_cast(dst), num_bytes), src_data, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); +} + +const std::unordered_map* GetDmlToHostMemCpyFunction() { + static std::unordered_map map{ + {OrtDevice::GPU, DmlToCpuMemCpy}}; + + return ↦ +} + #endif #ifdef USE_CANN diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.h b/onnxruntime/python/onnxruntime_pybind_mlvalue.h index 4ac9c70468..e3f277bcb9 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.h +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.h @@ -77,6 +77,12 @@ std::unique_ptr GetGPUDataTransfer(); AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id); +void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes); + +void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes); + +const std::unordered_map* GetDmlToHostMemCpyFunction(); + #endif #ifdef USE_CANN diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index f320707697..761757535b 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -11,6 +11,7 @@ namespace python { namespace py = pybind11; void CreateInferencePybindStateModule(py::module& m); +void CreateQuantPybindModule(py::module& m); PYBIND11_MODULE(onnxruntime_pybind11_state, m) { CreateInferencePybindStateModule(m); @@ -23,6 +24,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { m.def("get_version_string", []() -> std::string { return ORT_VERSION; }); m.def("get_build_info", []() -> std::string { return ORT_BUILD_INFO; }); + CreateQuantPybindModule(m); } } // namespace python } // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index f9d908e0ac..dc4a4dcc13 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -63,7 +63,12 @@ void addOrtValueMethods(pybind11::module& m) { // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy); - +#elif USE_DML + // InputDeflist is null because OrtValue creation is not tied to a specific model + // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) + // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML + CreateGenericMLValue( + nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy); #else throw std::runtime_error( "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " @@ -126,6 +131,12 @@ void addOrtValueMethods(pybind11::module& m) { values_type, *(ml_value->GetMutable()), CpuToRocmMemCpy); +#elif USE_DML + onnxruntime::python::CopyDataToTensor( + py_values, + values_type, + *(ml_value->GetMutable()), + CpuToDmlMemCpy); #else throw std::runtime_error( "Unsupported GPU device: Cannot find the supported GPU device."); @@ -158,12 +169,18 @@ void addOrtValueMethods(pybind11::module& m) { throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); } allocator = GetCudaAllocator(device.Id()); -#elif USE_DML - allocator = GetDmlAllocator(device.Id()); #else throw std::runtime_error( "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " "Please use the CUDA package of OnnxRuntime to use this feature."); +#endif + } else if (strcmp(GetDeviceName(device), DML) == 0) { +#if USE_DML + allocator = GetDmlAllocator(device.Id()); +#else + throw std::runtime_error( + "Can't allocate memory on the DirectML device using this package of OnnxRuntime. " + "Please use the DirectML package of OnnxRuntime to use this feature."); #endif } else { throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device"); @@ -290,11 +307,13 @@ void addOrtValueMethods(pybind11::module& m) { #ifdef USE_CUDA GetPyObjFromTensor(ml_value->Get(), obj, nullptr, GetCudaToHostMemCpyFunction()); #elif USE_ROCM - GetPyObjFromTensor(ml_value->Get(), obj, nullptr, GetRocmToHostMemCpyFunction()); + GetPyObjFromTensor(ml_value->Get(), obj, nullptr, GetRocmToHostMemCpyFunction()); #elif USE_CANN - GetPyObjFromTensor(ml_value->Get(), obj, nullptr, GetCannToHostMemCpyFunction()); + GetPyObjFromTensor(ml_value->Get(), obj, nullptr, GetCannToHostMemCpyFunction()); +#elif USE_DML + GetPyObjFromTensor(ml_value->Get(), obj, nullptr, GetDmlToHostMemCpyFunction()); #else - GetPyObjFromTensor(ml_value->Get(), obj, nullptr, nullptr); + GetPyObjFromTensor(ml_value->Get(), obj, nullptr, nullptr); #endif return obj; }) diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc new file mode 100644 index 0000000000..04dfa9b51e --- /dev/null +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "contrib_ops/cpu/quantization/dequantize_blockwise.h" +#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h" +#include "core/util/thread_utils.h" + +namespace pybind11 { +namespace detail { +// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)' +constexpr int NPY_FLOAT16 = 23; +template <> +struct npy_format_descriptor { + static constexpr auto name = _("float16"); + static pybind11::dtype dtype() { + handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16); + return reinterpret_borrow(ptr); + } + static std::string format() { + // following: https://docs.python.org/3/library/struct.html#format-characters + return "e"; + } +}; +} // namespace detail +} // namespace pybind11 + +namespace onnxruntime { +namespace python { + +namespace py = pybind11; +using namespace onnxruntime; + +template +void QuantizeMatMul4BitsBlockwise( + py::array_t dst, // shape: [ N, block_per_K, block_blob_size ] + py::array_t src, // shape: [K, N] + py::array_t scale, // shape: [N, block_per_K] + py::array_t zero_points, // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2] + int32_t block_size, + int32_t N, + int32_t K, + bool is_symmetric) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + + py::buffer_info dst_buf = dst.request(); + py::buffer_info src_buf = src.request(); + py::buffer_info scale_buf = scale.request(); + py::buffer_info zp_buf = zero_points.request(); + + contrib::QuantizeBlockwise( + static_cast(dst_buf.ptr), + static_cast(src_buf.ptr), + static_cast(scale_buf.ptr), + is_symmetric ? nullptr : static_cast(zp_buf.ptr), + block_size, + 4, + N, + K, + tp.get()); +} + +template +void QuantizeMatMulBnb4Blockwise( + py::array_t dst, + py::array_t src, + py::array_t absmax, + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + + py::buffer_info dst_buf = dst.request(); + py::buffer_info src_buf = src.request(); + py::buffer_info absmax_buf = absmax.request(); + + contrib::QuantizeBlockwiseBnb4( + static_cast(dst_buf.ptr), + static_cast(src_buf.ptr), + static_cast(absmax_buf.ptr), + block_size, + quant_type, + N, + K, + tp.get()); +} + +void CreateQuantPybindModule(py::module& m) { + m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); + m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); + m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); + m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); +} + +} // namespace python +} // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index a8c217b0ff..3a97777287 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -59,7 +59,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) { onnxruntime::ArmNNProviderFactoryCreator::Create(0), #endif #ifdef USE_DML - onnxruntime::DMLProviderFactoryCreator::Create(0, /*skip_software_device_check*/ true), + onnxruntime::DMLProviderFactoryCreator::Create(0, false, false, false), #endif #ifdef USE_NNAPI onnxruntime::NnapiProviderFactoryCreator::Create(0, std::optional()), diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 5ac20739c4..c336b1537a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -53,6 +53,7 @@ namespace onnxruntime { #endif // _MSC_VER #include +#include #if defined(_MSC_VER) #pragma warning(disable : 4267 4996 4503 4003) @@ -85,7 +86,7 @@ struct AsyncResource { std::vector feed_names; std::vector feed_names_raw; - std::vector fetches_raw; + std::vector fetches_raw; // will be released during destruction std::vector fetch_names; std::vector fetch_names_raw; @@ -106,6 +107,15 @@ struct AsyncResource { fetch_names.reserve(sz); fetch_names_raw.reserve(sz); } + + ~AsyncResource() { + std::for_each(fetches_raw.begin(), fetches_raw.end(), [](const OrtValue* fetch) { + if (fetch) { + std::unique_ptr fetch_recycler(fetch); + } + }); + fetches_raw.clear(); + } }; void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr ort_status) { @@ -227,7 +237,11 @@ const char* GetDeviceName(const OrtDevice& device) { case OrtDevice::CPU: return CPU; case OrtDevice::GPU: +#ifdef USE_DML + return DML; +#else return CUDA; +#endif case OrtDevice::FPGA: return "FPGA"; case OrtDevice::NPU: @@ -873,18 +887,10 @@ std::unique_ptr CreateExecutionProviderInstance( #endif } else if (type == kDmlExecutionProvider) { #ifdef USE_DML - int device_id = 0; - auto it = provider_options_map.find(type); - if (it != provider_options_map.end()) { - for (auto option : it->second) { - if (option.first == "device_id") { - if (!option.second.empty()) { - device_id = std::stoi(option.second); - } - } - } - } - return onnxruntime::DMLProviderFactoryCreator::Create(device_id)->CreateProvider(); + auto cit = provider_options_map.find(type); + return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions( + cit == provider_options_map.end() ? ProviderOptions{} : cit->second) + ->CreateProvider(); #endif } else if (type == kNnapiExecutionProvider) { #if defined(USE_NNAPI) diff --git a/onnxruntime/python/tools/kernel_explorer/device_array.h b/onnxruntime/python/tools/kernel_explorer/device_array.h index bb868c2b7a..12c526fa0c 100644 --- a/onnxruntime/python/tools/kernel_explorer/device_array.h +++ b/onnxruntime/python/tools/kernel_explorer/device_array.h @@ -62,8 +62,8 @@ class DeviceArray { private: std::shared_ptr device_; void* host_; - ssize_t size_; - ssize_t itemsize_; + py::ssize_t size_; + py::ssize_t itemsize_; }; } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu new file mode 100644 index 0000000000..3504ce1beb --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. + +#include +#include +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/device_array.h" +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct DequantizeBnb4Params : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + int quant_type_; + T* output_; + const uint8_t* quant_; + const T* absmax_; + T* quant_map_buffer_; + int n_; + int k_; +}; + +template +class DequantizeBnb4 : public IKernelExplorer { + public: + DequantizeBnb4( + int quant_type, + DeviceArray& output, + DeviceArray& quant, + DeviceArray& absmax, + DeviceArray& quant_map_buffer, + int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.quant_type_ = quant_type; + params_.output_ = static_cast(output.ptr()); + params_.quant_ = static_cast(quant.ptr()); + params_.absmax_ = static_cast(absmax.ptr()); + params_.quant_map_buffer_ = static_cast(quant_map_buffer.ptr()); + params_.n_ = n; + params_.k_ = k; + } + + void Run() override { + ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap( + params_.quant_type_, + params_.quant_map_buffer_, + params_.StreamHandle())); + ORT_THROW_IF_ERROR(contrib::cuda::DequantizeBnb4( + params_.quant_map_buffer_, + params_.output_, + params_.quant_, + params_.absmax_, + 64, + params_.n_ * params_.k_, + params_.StreamHandle())); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = DequantizeBnb4Params; + ParamsT params_{}; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(DequantizeBnb4, half); + REGISTER_OP(DequantizeBnb4, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu new file mode 100644 index 0000000000..9b5e4079a7 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_int4.cu @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. + +#include +#include +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/device_array.h" +#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct DequantizeInt4Params : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + T* output_; + const uint8_t* quant_; + const T* scales_; + const uint8_t* zero_points_; + int n_; + int k_; +}; + +template +class DequantizeInt4 : public IKernelExplorer { + public: + DequantizeInt4(DeviceArray& output, DeviceArray& quant, DeviceArray& scales, int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.output_ = static_cast(output.ptr()); + params_.quant_ = static_cast(quant.ptr()); + params_.scales_ = static_cast(scales.ptr()); + params_.zero_points_ = nullptr; + params_.n_ = n; + params_.k_ = k; + } + + void Run() override { + ORT_THROW_IF_ERROR(contrib::cuda::Dequantize4Bits( + params_.output_, + params_.quant_, + params_.scales_, + params_.zero_points_, + params_.k_, + params_.n_, + 32, + params_.StreamHandle())); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = DequantizeInt4Params; + ParamsT params_{}; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(DequantizeInt4, half); + REGISTER_OP(DequantizeInt4, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu new file mode 100644 index 0000000000..fd9e9c4fd1 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. + +#include +#include + +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cuda/cuda_stream_handle.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh" +#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct GemmBenchmarkParams : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + T* output_; + const T* a_; + const T* b_; + int m_; + int n_; + int k_; + cublasHandle_t cublas_handle; +}; + +template +class GemmBenchmark : public IKernelExplorer { + public: + GemmBenchmark(DeviceArray& output, DeviceArray& a, DeviceArray& b, int m, int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.output_ = static_cast(output.ptr()); + params_.a_ = static_cast(a.ptr()); + params_.b_ = static_cast(b.ptr()); + params_.m_ = m; + params_.n_ = n; + params_.k_ = k; + + CUBLAS_CALL_THROW(cublasCreate(&(params_.cublas_handle))); + CUDA_CALL_THROW(cudaGetDeviceProperties(&device_prop_, 0)); + } + + void Run() override { + typedef typename ToCudaType::MappedType CudaT; + CudaT one = ToCudaType::FromFloat(1.0f); + CudaT zero = ToCudaType::FromFloat(0.0f); + CUBLAS_CALL_THROW(cublasGemmHelper( + params_.cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + params_.n_, params_.m_, params_.k_, + &one, + reinterpret_cast(params_.b_), + params_.n_, + reinterpret_cast(params_.a_), + params_.k_, + &zero, + params_.output_, + params_.n_, + device_prop_)); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = GemmBenchmarkParams; + ParamsT params_{}; + cudaDeviceProp device_prop_; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(GemmBenchmark, half); + REGISTER_OP(GemmBenchmark, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu new file mode 100644 index 0000000000..9e8c4cd7be --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_4bits.cu @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. + +#include +#include +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh" +#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct MatrixFloatInt4Params : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + T* output_; + const T* a_; + const uint8_t* b_; + const T* scales_; + const uint8_t* zero_points_; + int m_; + int n_; + int k_; +}; + +template +class MatrixFloatInt4 : public IKernelExplorer { + public: + MatrixFloatInt4(DeviceArray& output, + DeviceArray& a, + DeviceArray& b, + DeviceArray& scales, + int m, int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.output_ = static_cast(output.ptr()); + params_.a_ = static_cast(a.ptr()); + params_.b_ = static_cast(b.ptr()); + params_.scales_ = static_cast(scales.ptr()); + params_.zero_points_ = nullptr; + params_.m_ = m; + params_.n_ = n; + params_.k_ = k; + + CUDA_CALL_THROW(cudaGetDeviceProperties(&device_prop_, 0)); + } + + MatrixFloatInt4(DeviceArray& output, + DeviceArray& a, + DeviceArray& b, + DeviceArray& scales, + DeviceArray& zeropoints, + int m, int n, int k) : MatrixFloatInt4(output, a, b, scales, m, n, k) { + params_.zero_points_ = static_cast(zeropoints.ptr()); + } + + void Run() override { + contrib::cuda::TryMatMul4Bits( + params_.output_, + params_.a_, + params_.b_, + params_.scales_, + params_.zero_points_, + params_.m_, + params_.n_, + params_.k_, + 32, + static_cast(device_prop_.sharedMemPerBlock), + params_.StreamHandle()); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = MatrixFloatInt4Params; + ParamsT params_{}; + cudaDeviceProp device_prop_; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(MatrixFloatInt4, half); + REGISTER_OP(MatrixFloatInt4, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu new file mode 100644 index 0000000000..e4cd835653 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. + +#include +#include +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh" +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" +#include "contrib_ops/cuda/quantization/matmul_bnb4.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct MatrixFloatBnb4Params : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + int quant_type_; + T* output_; + const T* a_; + const uint8_t* b_; + const T* absmax_; + T* quant_map_buffer_; + int m_; + int n_; + int k_; +}; + +template +class MatrixFloatBnb4 : public IKernelExplorer { + public: + MatrixFloatBnb4(DeviceArray& output, + DeviceArray& a, + DeviceArray& b, + DeviceArray& absmax, + DeviceArray& quant_map_buffer, + int quant_type, int m, int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.output_ = static_cast(output.ptr()); + params_.a_ = static_cast(a.ptr()); + params_.b_ = static_cast(b.ptr()); + params_.absmax_ = static_cast(absmax.ptr()); + params_.quant_map_buffer_ = static_cast(quant_map_buffer.ptr()); + params_.quant_type_ = quant_type; + params_.m_ = m; + params_.n_ = n; + params_.k_ = k; + } + + void Run() override { + ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap( + params_.quant_type_, + params_.quant_map_buffer_, + params_.StreamHandle())); + contrib::cuda::TryMatMulBnb4( + params_.quant_map_buffer_, + params_.output_, + params_.a_, + params_.b_, + params_.absmax_, + params_.m_, + params_.n_, + params_.k_, + 64, + params_.StreamHandle()); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = MatrixFloatBnb4Params; + ParamsT params_{}; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(MatrixFloatBnb4, half); + REGISTER_OP(MatrixFloatBnb4, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py new file mode 100644 index 0000000000..140151aadc --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py @@ -0,0 +1,92 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import sys +from dataclasses import dataclass + +import kernel_explorer as ke +import numpy as np +from utils import dtype_to_bytes + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: "DequantizeBnb4_half" in x, dir(ke))), + "float32": list(filter(lambda x: "DequantizeBnb4_float" in x, dir(ke))), + } + return type_map[dtype] + + +quant_enums = {"FP4": 0, "NF4": 1} + + +dtypes = ["float16", "float32"] +quant_types = ["FP4", "NF4"] + + +@dataclass +class DequantizeBnb4Metric(ke.BandwidthMetric): + quant_type: str + n: int + k: int + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s" + f" {self.quant_type} {self.dtype} n={self.n} k={self.k} {self.name}" + ) + + +def profile_dequantize_int4_func(qt, n, k, dtype, func): + np.random.seed(0) + block_size = 64 + numel = n * k + output = np.random.rand(n, k).astype(dtype) + quant = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8") + absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype) + quant_map_buffer = np.zeros(16).astype(dtype) + + output_d = ke.DeviceArray(output) + quant_d = ke.DeviceArray(quant) + absmax_d = ke.DeviceArray(absmax) + quant_map_buffer_d = ke.DeviceArray(quant_map_buffer) + f = getattr(ke, func) + my_op = f(quant_enums[qt], output_d, quant_d, absmax_d, quant_map_buffer_d, n, k) + duration_ms = my_op.Profile() + total_bytes = numel / 2 + (numel + numel / block_size) * dtype_to_bytes(dtype) + + ke.report(DequantizeBnb4Metric(func, dtype, duration_ms, total_bytes, qt, n, k)) + + +def profile_with_args(qt, n, k, dtype, sort): + with ke.benchmark(sort): + for func in dtype_to_funcs(dtype): + profile_dequantize_int4_func(qt, n, k, dtype, func) + + +def profile(): + for qt in quant_types: + for dt in dtypes: + for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): + profile_with_args(qt, n, k, dt, True) + print() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + group = parser.add_argument_group("profile with args") + group.add_argument("n", type=int) + group.add_argument("k", type=int) + group.add_argument("quant_type", choices=quant_types) + group.add_argument("dtype", choices=dtypes) + group.add_argument("--sort", action="store_true") + + if len(sys.argv) == 1: + profile() + else: + args = parser.parse_args() + profile_with_args(args.quant_type, args.n, args.k, args.dtype, args.sort) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py new file mode 100644 index 0000000000..7088039f9e --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_int4.py @@ -0,0 +1,78 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import sys +from dataclasses import dataclass + +import kernel_explorer as ke +import numpy as np +from utils import dtype_to_bytes + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: "DequantizeInt4_half" in x, dir(ke))), + "float32": list(filter(lambda x: "DequantizeInt4_float" in x, dir(ke))), + } + return type_map[dtype] + + +dtypes = ["float16", "float32"] + + +@dataclass +class DequantizeInt4Metric(ke.BandwidthMetric): + n: int + k: int + + def report(self): + return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} n={self.n} k={self.k} {self.name}" + + +def profile_dequantize_int4_func(n, k, dtype, func): + np.random.seed(0) + output = np.random.rand(n, k).astype(dtype) + quant = np.random.randint(low=0, high=127, size=(n, (k + 31) // 32, 16)).astype("uint8") + scales = np.random.rand(n, (k + 31) // 32).astype(dtype) + + output_d = ke.DeviceArray(output) + quant_d = ke.DeviceArray(quant) + scales_d = ke.DeviceArray(scales) + f = getattr(ke, func) + my_op = f(output_d, quant_d, scales_d, n, k) + duration_ms = my_op.Profile() + total_bytes = (n * k) / 2 + (n * k + n * k / 32) * dtype_to_bytes(dtype) + + ke.report(DequantizeInt4Metric(func, dtype, duration_ms, total_bytes, n, k)) + + +def profile_with_args(n, k, dtype, sort): + with ke.benchmark(sort): + for func in dtype_to_funcs(dtype): + profile_dequantize_int4_func(n, k, dtype, func) + + +def profile(): + for dt in dtypes: + for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): + profile_with_args(n, k, dt, True) + print() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + group = parser.add_argument_group("profile with args") + group.add_argument("n", type=int) + group.add_argument("k", type=int) + group.add_argument("dtype", choices=dtypes) + group.add_argument("--sort", action="store_true") + + if len(sys.argv) == 1: + profile() + else: + args = parser.parse_args() + profile_with_args(args.n, args.k, args.dtype, args.sort) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py index e378f3e1cc..8182cdb175 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_test.py @@ -179,6 +179,7 @@ def profile_with_args(dtype, transa, transb, m, n, k, sort): profile_gemm_func(getattr(ke, "RocBlasGemm" + dtype_suffix), dtype, transa, transb, m, n, k) profile_gemm_func(getattr(ke, "CKGemm" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k) profile_gemm_func(getattr(ke, "GemmTunable" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k) + profile_gemm_func(getattr(ke, "GemmBenchmark" + dtype_suffix), dtype, transa, transb, m, n, k) if ke.is_hipblaslt_available(): profile_gemm_func( getattr(ke, "GemmHipBlasLt" + dtype_suffix + transab_suffix), dtype, transa, transb, m, n, k diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py new file mode 100644 index 0000000000..9cb937a13f --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_4bits.py @@ -0,0 +1,132 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import sys +from dataclasses import dataclass + +import kernel_explorer as ke +import numpy as np +from utils import dtype_to_bytes + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: "MatrixFloatInt4_half" in x, dir(ke))), + "float32": list(filter(lambda x: "MatrixFloatInt4_float" in x, dir(ke))), + } + return type_map[dtype] + + +def dtype_to_funcs_cublas(dtype): + type_map = { + "float16": list(filter(lambda x: "GemmBenchmark_half" in x, dir(ke))), + "float32": list(filter(lambda x: "GemmBenchmark_float" in x, dir(ke))), + } + return type_map[dtype] + + +dtypes = ["float16", "float32"] + + +@dataclass +class MatrixMulMetric(ke.BandwidthMetric): + m: int + n: int + k: int + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}" + ) + + +@dataclass +class MatrixFpInt4Metric(MatrixMulMetric): + is_symmetric: bool + + def report(self): + return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} is_symmetric={self.is_symmetric} {self.name}" + + +def profile_matmul_fp_int4_func(m, n, k, dtype, func, is_symmetric): + np.random.seed(0) + output = np.random.rand(m, n).astype(dtype) + a = np.random.rand(m, k).astype(dtype) + b = np.random.randint(low=0, high=127, size=(n, (k + 31) // 32, 16)).astype("uint8") + scales = np.random.rand(n * ((k + 31) // 32)).astype(dtype) + zeropoints = np.random.rand((n * ((k + 31) // 32) + 1) // 2).astype(dtype) + + output_d = ke.DeviceArray(output) + a_d = ke.DeviceArray(a) + b_d = ke.DeviceArray(b) + scales_d = ke.DeviceArray(scales) + zeropoints_d = ke.DeviceArray(zeropoints) + f = getattr(ke, func) + + my_op = ( + f(output_d, a_d, b_d, scales_d, m, n, k) + if is_symmetric + else f(output_d, a_d, b_d, scales_d, zeropoints_d, m, n, k) + ) + duration_ms = my_op.Profile() + total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) + + ke.report(MatrixFpInt4Metric(func, dtype, duration_ms, total_bytes, m, n, k, is_symmetric)) + + +def profile_gemm_func(m, n, k, dtype, func): + np.random.seed(0) + output = np.random.rand(m, n).astype(dtype) + a = np.random.rand(m, k).astype(dtype) + b = np.random.rand(k, n).astype(dtype) + + output_d = ke.DeviceArray(output) + a_d = ke.DeviceArray(a) + b_d = ke.DeviceArray(b) + f = getattr(ke, func) + my_op = f(output_d, a_d, b_d, m, n, k) + duration_ms = my_op.Profile() + total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) + + ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k)) + + +def profile_with_args(m, n, k, dtype, sort): + with ke.benchmark(sort): + for func in dtype_to_funcs(dtype): + profile_matmul_fp_int4_func(m, n, k, dtype, func, True) + + for func in dtype_to_funcs(dtype): + profile_matmul_fp_int4_func(m, n, k, dtype, func, False) + + for func in dtype_to_funcs_cublas(dtype): + profile_gemm_func(m, n, k, dtype, func) + + +def profile(): + dims_m = [1] + for dt in dtypes: + for m in dims_m: + for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): + profile_with_args(m, n, k, dt, False) + print() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + group = parser.add_argument_group("profile with args") + group.add_argument("m", type=int) + group.add_argument("n", type=int) + group.add_argument("k", type=int) + group.add_argument("dtype", choices=dtypes) + group.add_argument("--sort", action="store_true") + + if len(sys.argv) == 1: + profile() + else: + args = parser.parse_args() + profile_with_args(args.m, args.n, args.k, args.dtype, args.sort) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py new file mode 100644 index 0000000000..4a9489050f --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py @@ -0,0 +1,136 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import sys +from dataclasses import dataclass + +import kernel_explorer as ke +import numpy as np +from utils import dtype_to_bytes + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: "MatrixFloatBnb4_half" in x, dir(ke))), + "float32": list(filter(lambda x: "MatrixFloatBnb4_float" in x, dir(ke))), + } + return type_map[dtype] + + +def dtype_to_funcs_cublas(dtype): + type_map = { + "float16": list(filter(lambda x: "GemmBenchmark_half" in x, dir(ke))), + "float32": list(filter(lambda x: "GemmBenchmark_float" in x, dir(ke))), + } + return type_map[dtype] + + +quant_enums = {"FP4": 0, "NF4": 1} + + +dtypes = ["float16", "float32"] +quant_types = ["FP4", "NF4"] + + +@dataclass +class MatrixMulMetric(ke.BandwidthMetric): + m: int + n: int + k: int + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}" + ) + + +@dataclass +class MatrixFpBnb4Metric(MatrixMulMetric): + quant_type: str + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s" + f" {self.quant_type} {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}" + ) + + +def profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func): + np.random.seed(0) + block_size = 64 + numel = n * k + output = np.random.rand(m, n).astype(dtype) + a = np.random.rand(m, k).astype(dtype) + b = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8") + absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype) + quant_map_buffer = np.zeros(16).astype(dtype) + + output_d = ke.DeviceArray(output) + a_d = ke.DeviceArray(a) + b_d = ke.DeviceArray(b) + absmax_d = ke.DeviceArray(absmax) + quant_map_buffer_d = ke.DeviceArray(quant_map_buffer) + f = getattr(ke, func) + + my_op = f(output_d, a_d, b_d, absmax_d, quant_map_buffer_d, quant_enums[qt], m, n, k) + duration_ms = my_op.Profile() + total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) + + ke.report(MatrixFpBnb4Metric(func, dtype, duration_ms, total_bytes, m, n, k, qt)) + + +def profile_gemm_func(m, n, k, dtype, func): + np.random.seed(0) + output = np.random.rand(m, n).astype(dtype) + a = np.random.rand(m, k).astype(dtype) + b = np.random.rand(k, n).astype(dtype) + + output_d = ke.DeviceArray(output) + a_d = ke.DeviceArray(a) + b_d = ke.DeviceArray(b) + f = getattr(ke, func) + my_op = f(output_d, a_d, b_d, m, n, k) + duration_ms = my_op.Profile() + total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) + + ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k)) + + +def profile_with_args(qt, m, n, k, dtype, sort): + with ke.benchmark(sort): + for func in dtype_to_funcs(dtype): + profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func) + + for func in dtype_to_funcs_cublas(dtype): + profile_gemm_func(m, n, k, dtype, func) + + +def profile(): + dims_m = [1] + for qt in quant_types: + for dt in dtypes: + for m in dims_m: + for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): + profile_with_args(qt, m, n, k, dt, False) + print() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + group = parser.add_argument_group("profile with args") + group.add_argument("m", type=int) + group.add_argument("n", type=int) + group.add_argument("k", type=int) + group.add_argument("quant_type", choices=quant_types) + group.add_argument("dtype", choices=dtypes) + group.add_argument("--sort", action="store_true") + + if len(sys.argv) == 1: + profile() + else: + args = parser.parse_args() + profile_with_args(args.quant_type, args.m, args.n, args.k, args.dtype, args.sort) diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index bdf00f2110..eb570fea6d 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -22,7 +22,7 @@ class TensorData: - _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges"]) + _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"]) def __init__(self, **kwargs): for k, v in kwargs.items(): @@ -55,7 +55,7 @@ def __init__(self, calibration_method, data: Dict[str, Union[TensorData, Tuple]] self.data[k] = TensorData(lowest=v[0], highest=v[1]) continue if len(v) == 4: - self.data[k] = TensorData(lowest=v[0], highest=v[1], histogram=v[2], bins=v[3]) + self.data[k] = TensorData(lowest=v[0], highest=v[1], hist=v[2], bins=v[3]) continue raise TypeError(f"Unexpected tuple for {k:r}, it has {len(v)} elements: {v}.") if not isinstance(v, TensorData): @@ -115,6 +115,7 @@ def __init__( augmented_model_path="augmented_model.onnx", symmetric=False, use_external_data_format=False, + data_types_to_calibrate: list[TensorProto.DataType] = [TensorProto.FLOAT], ): """ :param model_path: ONNX model to calibrate. It should be a model file path @@ -138,6 +139,7 @@ def __init__( self.augment_model = None self.infer_session = None self.execution_providers = ["CPUExecutionProvider"] + self.tensor_types_to_calibrate = data_types_to_calibrate def set_execution_providers(self, execution_providers=["CPUExecutionProvider"]): # noqa: B006 """ @@ -171,7 +173,6 @@ def select_tensors_to_calibrate(self, model: ModelProto): initializer = {init.name for init in model.graph.initializer} tensors_to_calibrate = set() - tensor_type_to_calibrate = {TensorProto.FLOAT} for node in model.graph.node: if not self.op_types_to_calibrate or node.op_type in self.op_types_to_calibrate: @@ -180,7 +181,7 @@ def select_tensors_to_calibrate(self, model: ModelProto): vi = value_infos[tensor_name] if ( vi.type.HasField("tensor_type") - and (vi.type.tensor_type.elem_type in tensor_type_to_calibrate) + and (vi.type.tensor_type.elem_type in self.tensor_types_to_calibrate) and (tensor_name not in initializer) ): tensors_to_calibrate.add(tensor_name) @@ -224,6 +225,7 @@ def __init__( use_external_data_format=False, moving_average=False, averaging_constant=0.01, + data_types_to_calibrate: list[TensorProto.DataType] = [TensorProto.FLOAT], ): """ :param model_path: ONNX model to calibrate. It is a model path @@ -240,6 +242,7 @@ def __init__( augmented_model_path=augmented_model_path, symmetric=symmetric, use_external_data_format=use_external_data_format, + data_types_to_calibrate=data_types_to_calibrate, ) self.intermediate_outputs = [] self.calibrate_tensors_range = None @@ -256,7 +259,7 @@ def augment_graph(self): model and ensures their outputs are stored as part of the graph output :return: augmented ONNX model """ - tensors, _ = self.select_tensors_to_calibrate(self.model) + tensors, value_infos = self.select_tensors_to_calibrate(self.model) reshape_shape_name = str(uuid.uuid4()) reshape_shape = numpy_helper.from_array(np.array([1], dtype=np.int64), reshape_shape_name) self.model.graph.initializer.append(reshape_shape) @@ -280,8 +283,10 @@ def add_reduce_min_max(tensor_name, reduce_op_name): name=intermediate_output, ) + out_dtype = value_infos[tensor].type.tensor_type.elem_type + self.model.graph.node.extend([reduce_node, reshape_node]) - self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, TensorProto.FLOAT, [1])) + self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, out_dtype, [1])) for tensor in tensors: add_reduce_min_max(tensor, "ReduceMin") @@ -396,6 +401,7 @@ def __init__( num_quantized_bins=2048, percentile=99.999, scenario="same", + data_types_to_calibrate: list[TensorProto.DataType] = [TensorProto.FLOAT] ): """ :param model_path: ONNX model to calibrate. It is a model path. @@ -415,6 +421,7 @@ def __init__( augmented_model_path=augmented_model_path, symmetric=symmetric, use_external_data_format=use_external_data_format, + data_types_to_calibrate=data_types_to_calibrate, ) self.intermediate_outputs = [] self.calibrate_tensors_range = None @@ -515,6 +522,7 @@ def __init__( symmetric=False, num_bins=128, num_quantized_bins=128, + data_types_to_calibrate: list[TensorProto] = [TensorProto.FLOAT], ): """ :param model_path: ONNX model to calibrate. It is a model path @@ -535,6 +543,7 @@ def __init__( symmetric=symmetric, num_bins=num_bins, num_quantized_bins=num_quantized_bins, + data_types_to_calibrate=data_types_to_calibrate, ) @@ -549,6 +558,7 @@ def __init__( symmetric=False, num_bins=2048, percentile=99.999, + data_types_to_calibrate: list[TensorProto] = [TensorProto.FLOAT], ): """ :param model_path: ONNX model to calibrate. It is a model path @@ -569,6 +579,7 @@ def __init__( symmetric=symmetric, num_bins=num_bins, percentile=percentile, + data_types_to_calibrate=data_types_to_calibrate, ) @@ -582,6 +593,7 @@ def __init__( method="distribution", num_bins=128, scenario="same", + data_types_to_calibrate: list[TensorProto] = [TensorProto.FLOAT], ): """ :param model_path: ONNX model to calibrate. It is a model path @@ -604,6 +616,7 @@ def __init__( method=method, num_bins=num_bins, scenario=scenario, + data_types_to_calibrate=data_types_to_calibrate, ) @@ -1004,6 +1017,7 @@ def create_calibrator( calibrate_method=CalibrationMethod.MinMax, use_external_data_format=False, extra_options={}, # noqa: B006 + data_types_to_calibrate: list[TensorProto.DataType] = [TensorProto.FLOAT], ): calibrator = None if calibrate_method == CalibrationMethod.MinMax: @@ -1019,6 +1033,7 @@ def create_calibrator( symmetric=symmetric, moving_average=moving_average, averaging_constant=averaging_constant, + data_types_to_calibrate=data_types_to_calibrate, ) elif calibrate_method == CalibrationMethod.Entropy: # default settings for entropy algorithm @@ -1033,6 +1048,7 @@ def create_calibrator( symmetric=symmetric, num_bins=num_bins, num_quantized_bins=num_quantized_bins, + data_types_to_calibrate=data_types_to_calibrate, ) elif calibrate_method == CalibrationMethod.Percentile: # default settings for percentile algorithm @@ -1047,6 +1063,7 @@ def create_calibrator( symmetric=symmetric, num_bins=num_bins, percentile=percentile, + data_types_to_calibrate=data_types_to_calibrate, ) elif calibrate_method == CalibrationMethod.Distribution: @@ -1061,6 +1078,7 @@ def create_calibrator( use_external_data_format=use_external_data_format, num_bins=num_bins, scenario=scenario, + data_types_to_calibrate=data_types_to_calibrate, ) if calibrator: diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py new file mode 100644 index 0000000000..fea9e5e8cb --- /dev/null +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -0,0 +1,229 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import argparse +import logging +import os +from typing import List, Tuple + +import numpy as np +import numpy.typing as npt +import onnx +from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto + +from onnxruntime.capi._pybind_state import quantize_matmul_4bits + +from .onnx_model import ONNXModel +from .quant_utils import attribute_to_kwarg + +logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.INFO) +logger = logging.getLogger(__name__) + + +class MatMul4BitsQuantizer: + """Perform 4b quantization of constant MatMul weights""" + + def __init__(self, model: ModelProto, block_size: int, is_symmetric: bool, nodes_to_exclude=None): + if nodes_to_exclude is None: + nodes_to_exclude = [] + self.model = ONNXModel(model) + self.block_size = block_size + self.is_symmetric = is_symmetric + self.nodes_to_exclude = set(nodes_to_exclude) + + @staticmethod + def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]: + for gid in range(len(graph_path) - 1, -1, -1): + graph = graph_path[gid] + for tensor in graph.initializer: + if tensor.name == name: + return tensor, graph + return None, None + + def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: + """4b quantize fp32 weight to a blob""" + + if len(fp32weight.shape) != 2: + raise ValueError("Current int4 block quantization only supports 2D tensors!") + rows, cols = fp32weight.shape + + block_size = self.block_size + blob_size = block_size // 2 + k_blocks = (rows + block_size - 1) // block_size + padded_rows = k_blocks * block_size + pad_len = padded_rows - rows + if pad_len > 0: + fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant") + + # block wise quantization, each block comes from a single column + packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") + scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) + zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8") + quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric) + + return (packed, scales, zero_point) + + def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto: + """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + + if node.op_type != "MatMul": + return node # only care about MatMul for now + + logger.info(f"start to quantize {node.name} ...") + if node.name in self.nodes_to_exclude: + logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") + return node + + inputB = node.input[1] # noqa: N806 + B, Bs_graph = MatMul4BitsQuantizer.__get_initializer(inputB, graph_stack) # noqa: N806 + if B is None: + logger.info("MatMul doesn't have const weight. Skip to quantize") + return node # only care about constant weight + + B_array = onnx.numpy_helper.to_array(B) # noqa: N806 + if len(B_array.shape) != 2: + logger.info("MatMul weight is not 2D. Skip to quantize") + return node # can only process 2-D matrix + + packed, scales, zero_points = self.int4_block_quant(B_array) + B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 + B_quant.name = B.name + "_Q4" + for input in Bs_graph.input: + if input.name == inputB: + Bs_graph.input.remove(input) + break + + scales_tensor = onnx.numpy_helper.from_array(scales) + scales_tensor.name = B.name + "_scales" + Bs_graph.initializer.extend([B_quant, scales_tensor]) + + input_names = [node.input[0], B_quant.name, scales_tensor.name] + if not self.is_symmetric: + zp_tensor = onnx.numpy_helper.from_array(zero_points) + zp_tensor.name = B.name + "_zero_points" + Bs_graph.initializer.extend([zp_tensor]) + input_names.append(zp_tensor.name) + + kwargs = {} + rows, cols = B_array.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["bits"] = 4 + kwargs["block_size"] = self.block_size + + matmul_q4_node = onnx.helper.make_node( + "MatMulNBits", + inputs=input_names, + outputs=[node.output[0]], + name=node.name + "_Q4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) + + logger.info(f"complete quantization of {node.name} ...") + + return matmul_q4_node + + def _process_subgraph(self, graph_stack: List[GraphProto]): + new_nodes = [] + graph = graph_stack[-1] + + for node in graph.node: + graph_attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] + if len(graph_attrs): + kwargs = {} + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + # recursive call to take care of sub-graph + graph_stack.append(attr.g) + kv = {attr.name: self._process_subgraph(graph_stack)} + elif attr.type == onnx.AttributeProto.GRAPHS: + value = [] + for subgraph in attr.graphs: + # recursive call to take care of sub-graph + graph_stack.append(subgraph) + value.extend([self._process_subgraph(graph_stack)]) + kv = {attr.name: value} + else: + kv = attribute_to_kwarg(attr) + kwargs.update(kv) + node = onnx.helper.make_node( # noqa: PLW2901 + node.op_type, node.input, node.output, name=node.name, **kwargs + ) + + new_nodes.append(self._q4_matmul_node_weight(node, graph_stack)) + + graph.ClearField("node") + graph.node.extend(new_nodes) + graph_stack.pop() + return graph + + def process(self): + # use a stack to keep track of sub-graphs + graph_stack = [self.model.graph()] + opset_import = self.model.opset_import() + + has_ms_domain = False + for opset in opset_import: + if opset.domain == "com.microsoft": + has_ms_domain = True + if not has_ms_domain: + opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) + + self._process_subgraph(graph_stack) + self.model.clean_initializers() + + +def parse_args(): + parser = argparse.ArgumentParser( + description="""Blockwise int4 quantization for MatMul 2D weight matrices. + +A weight matrix is partitioned into into blocks, where each block is a +continguous subset inside each column. Each block is quantized into a +set of 4b integers with a scaling factor and an optional offset. +""" + ) + + parser.add_argument("--input_model", required=True, help="Path to the input model file") + parser.add_argument("--output_model", required=True, help="Path to the output model file") + parser.add_argument("--block_size", required=False, default=32) + parser.add_argument( + "--symmetric", required=False, default=True, help="Indicate whether to quantize the model symmetrically" + ) + parser.add_argument("-v", "--verbose", required=False, action="store_true") + parser.set_defaults(verbose=False) + parser.add_argument( + "--nodes_to_exclude", + nargs="+", + type=str, + required=False, + default=[], + help="Specify the nodes to be excluded from quantization with node names", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + if args.verbose: + logger.setLevel(logging.DEBUG) + + input_model_path = args.input_model + output_model_path = args.output_model + + if os.path.exists(output_model_path): + logger.error(f"file {output_model_path} already exists") + raise Exception(f"file {output_model_path} already exists") + + model = onnx.load(input_model_path) + quant = MatMul4BitsQuantizer(model, args.block_size, args.symmetric, nodes_to_exclude=args.nodes_to_exclude) + quant.process() + quant.model.save_model_to_file(output_model_path, True) diff --git a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py new file mode 100644 index 0000000000..951746a089 --- /dev/null +++ b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py @@ -0,0 +1,240 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import argparse +import logging +import os +from typing import List, Tuple + +import numpy as np +import numpy.typing as npt +import onnx +from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto + +from onnxruntime.capi._pybind_state import quantize_matmul_bnb4 + +from .onnx_model import ONNXModel +from .quant_utils import attribute_to_kwarg + +logger = logging.getLogger(__name__) + + +class MatMulBnb4Quantizer: + """Perform 4b quantization of constant MatMul weights using FP4 or NF4 data type""" + + ################## + # quantization types, must be consistent with native code type + # Bnb_DataType_t defined in blockwise_quant_block_bnb4.h + + # 4b floating point with bias of 3 + FP4 = 0 + + # 4b NormalFloat + NF4 = 1 + + def __init__(self, model: ModelProto, quant_type: int, block_size: int, nodes_to_exclude=None): + nodes_to_exclude = nodes_to_exclude or [] + assert quant_type in [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4] + self.model = ONNXModel(model) + self.quant_type = quant_type + self.block_size = block_size + self.nodes_to_exclude = set(nodes_to_exclude) + + @staticmethod + def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]: + for gid in range(len(graph_path) - 1, -1, -1): + graph = graph_path[gid] + for tensor in graph.initializer: + if tensor.name == name: + return tensor, graph + return None, None + + def bnb4_block_quant(self, fpweight: npt.ArrayLike) -> np.ndarray: + """4b quantize fp32/fp16 weight""" + + if len(fpweight.shape) != 2: + raise ValueError("Current bnb4 block quantization only supports 2D tensors!") + # need to copy since the transposed weight still has the original memory layout + # Linear4bit quantizes its weight data which is the transposed weight + fpweight_t = fpweight.transpose().copy() + + rows, cols = fpweight.shape + numel = rows * cols + block_size = self.block_size + num_blocks = (numel + block_size - 1) // block_size + quantized_numel = (numel + 1) // 2 + + packed = np.zeros(quantized_numel, dtype="uint8") + absmax = np.zeros(num_blocks, dtype=fpweight.dtype) + # block wise quantization, fpweight_t is flattened and divided into blocks + quantize_matmul_bnb4(packed, fpweight_t, absmax, block_size, self.quant_type, cols, rows) + + return (packed, absmax) + + def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto: + """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + + if node.op_type != "MatMul": + return node # only care about MatMul for now + + logger.debug(f"start to quantize {node.name} ...") + if node.name in self.nodes_to_exclude: + logger.debug(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") + return node + + inputB = node.input[1] # noqa: N806 + B, Bs_graph = MatMulBnb4Quantizer.__get_initializer(inputB, graph_stack) # noqa: N806 + if B is None: + logger.debug("MatMul doesn't have const weight. Skip to quantize") + return node # only care about constant weight + + B_array = onnx.numpy_helper.to_array(B) # noqa: N806 + if len(B_array.shape) != 2: + logger.debug("MatMul weight is not 2D. Skip to quantize") + return node # can only process 2-D matrix + + packed, absmax = self.bnb4_block_quant(B_array) + B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 + B_quant.name = B.name + "_Bnb4" + for input in Bs_graph.input: + if input.name == inputB: + Bs_graph.input.remove(input) + break + + absmax_tensor = onnx.numpy_helper.from_array(absmax) + absmax_tensor.name = B.name + "_absmax" + + Bs_graph.initializer.extend([B_quant, absmax_tensor]) + + kwargs = {} + rows, cols = B_array.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["block_size"] = self.block_size + kwargs["quant_type"] = self.quant_type + + matmul_bnb4_node = onnx.helper.make_node( + "MatMulBnb4", + inputs=[node.input[0], B_quant.name, absmax_tensor.name], + outputs=[node.output[0]], + name=node.name + "_Bnb4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) + + logger.debug(f"complete quantization of {node.name} ...") + + return matmul_bnb4_node + + def _process_subgraph(self, graph_stack: List[GraphProto]): + new_nodes = [] + graph = graph_stack[-1] + + for node in graph.node: + graph_attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] + if len(graph_attrs): + kwargs = {} + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + # recursive call to take care of sub-graph + graph_stack.append(attr.g) + kv = {attr.name: self._process_subgraph(graph_stack)} + elif attr.type == onnx.AttributeProto.GRAPHS: + value = [] + for subgraph in attr.graphs: + # recursive call to take care of sub-graph + graph_stack.append(subgraph) + value.extend([self._process_subgraph(graph_stack)]) + kv = {attr.name: value} + else: + kv = attribute_to_kwarg(attr) + kwargs.update(kv) + node = onnx.helper.make_node( # noqa: PLW2901 + node.op_type, node.input, node.output, name=node.name, **kwargs + ) + + new_nodes.append(self._bnb4_matmul_node_weight(node, graph_stack)) + + graph.ClearField("node") + graph.node.extend(new_nodes) + graph_stack.pop() + return graph + + def process(self): + # use a stack to keep track of sub-graphs + graph_stack = [self.model.graph()] + opset_import = self.model.opset_import() + + has_ms_domain = False + for opset in opset_import: + if opset.domain == "com.microsoft": + has_ms_domain = True + if not has_ms_domain: + opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) + + self._process_subgraph(graph_stack) + self.model.clean_initializers() + + +def parse_args(): + parser = argparse.ArgumentParser( + description="""Blockwise FP4/NF4 quantization for MatMul 2D weight matrices. + +A weight matrix is partitioned into blocks, where each block is a contiguous +subset inside the flattened transposed weight matrix. Each block is quantized +into a set of 4b integers with an absolute value scaling factor. +""" + ) + + parser.add_argument("--input_model", required=True, help="Path to the input model file") + parser.add_argument("--output_model", required=True, help="Path to the output model file") + parser.add_argument( + "--quant_type", + required=False, + default=1, + options=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4], + help="Quantization data type. 0: FP4, 1: NF4", + ) + parser.add_argument( + "--block_size", + required=False, + default=64, + description="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64", + ) + parser.add_argument("-v", "--verbose", required=False, action="store_true") + parser.set_defaults(verbose=False) + parser.add_argument( + "--nodes_to_exclude", + nargs="+", + type=str, + required=False, + default=[], + help="Specify the nodes to be excluded from quantization with node names", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + if args.verbose: + logger.setLevel(logging.DEBUG) + + input_model_path = args.input_model + output_model_path = args.output_model + + if os.path.exists(output_model_path): + logger.error(f"file {output_model_path} already exists") + raise Exception(f"file {output_model_path} already exists") + + model = onnx.load(input_model_path) + quant = MatMulBnb4Quantizer(model, args.quant_type, args.block_size, nodes_to_exclude=args.nodes_to_exclude) + quant.process() + quant.model.save_model_to_file(output_model_path, True) diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 924d4c72b6..bb968d660c 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -112,8 +112,8 @@ def __init__( False if "ActivationSymmetric" not in self.extra_options else self.extra_options["ActivationSymmetric"] ) - self.activation_qType = activation_qType.tensor_type - self.weight_qType = weight_qType.tensor_type + self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType) + self.weight_qType = getattr(weight_qType, "tensor_type", weight_qType) """ Dictionary specifying the min and max values for tensors. It has following format: { diff --git a/onnxruntime/python/tools/quantization/operators/conv.py b/onnxruntime/python/tools/quantization/operators/conv.py index d23459b478..23f9eaf4b0 100644 --- a/onnxruntime/python/tools/quantization/operators/conv.py +++ b/onnxruntime/python/tools/quantization/operators/conv.py @@ -157,7 +157,7 @@ def quantize(self): nodes, ) = self.quantizer.quantize_activation(node, [0]) quant_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[1], onnx_proto.TensorProto.INT8, 0 + node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) quantized_input_names.append(quant_weight_tuple[0]) zero_point_names.append(quant_weight_tuple[1]) diff --git a/onnxruntime/python/tools/quantization/operators/convtranspose2d.py b/onnxruntime/python/tools/quantization/operators/convtranspose2d.py new file mode 100644 index 0000000000..3eb93b0bae --- /dev/null +++ b/onnxruntime/python/tools/quantization/operators/convtranspose2d.py @@ -0,0 +1,115 @@ +import numpy as np +import onnx +from onnx import onnx_pb as onnx_proto + +from ..quant_utils import TENSOR_NAME_QUANT_SUFFIX, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain +from .base_operator import QuantOperatorBase + + +class QLinearConvTranspose(QuantOperatorBase): + def __init__(self, onnx_quantizer, onnx_node): + super().__init__(onnx_quantizer, onnx_node) + + def quantize(self): + node = self.node + assert node.op_type == "ConvTranspose" + print(f"Custom quantization code for {node.op_type}") + + ( + data_found, + output_scale_name, + output_zp_name, + _, + _, + ) = self.quantizer._get_quantization_params(node.output[0]) + + if self.quantizer.is_input_a_initializer(node.input[1]) and self.quantizer.is_per_channel(): + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_activation(node, [0]) + quant_weight_tuple = self.quantizer.quantize_weight_per_channel( + node.input[1], onnx_proto.TensorProto.INT8, 0 + ) + quantized_input_names.append(quant_weight_tuple[0]) + zero_point_names.append(quant_weight_tuple[1]) + scale_names.append(quant_weight_tuple[2]) + else: + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_activation(node, [0]) + + ( + quantized_input_names_weight, + zero_point_names_weight, + scale_names_weight, + nodes_weight, + ) = self.quantizer.quantize_weight(node, [1], reduce_range=self.quantizer.reduce_range) + quantized_input_names.extend(quantized_input_names_weight) + zero_point_names.extend(zero_point_names_weight) + scale_names.extend(scale_names_weight) + nodes.extend(nodes_weight) + + if not data_found or quantized_input_names is None: + return super().quantize() + + quantized_bias_name = "" + bias_present = False + if len(node.input) == 3: + quantized_bias_name = self.quantizer.quantize_bias_static(node.input[2], node.input[0], node.input[1]) + bias_present = True + + qlinear_conv_output = node.output[0] + TENSOR_NAME_QUANT_SUFFIX + qlinear_conv_name = qlinear_conv_name = node.name + "_quant" if node.name != "" else "" + + kwargs = {} + kwargs["domain"] = ms_domain + for attribute in node.attribute: + kwargs.update(attribute_to_kwarg(attribute)) + qlinear_conv_inputs = [] + + # Input 0 + qlinear_conv_inputs.append(quantized_input_names[0]) + qlinear_conv_inputs.append(scale_names[0]) + qlinear_conv_inputs.append(zero_point_names[0]) + + # Input 1 + qlinear_conv_inputs.append(quantized_input_names[1]) + qlinear_conv_inputs.append(scale_names[1]) + qlinear_conv_inputs.append(zero_point_names[1]) + + # Output + qlinear_conv_inputs.append(output_scale_name) + qlinear_conv_inputs.append(output_zp_name) + + if bias_present: + qlinear_conv_inputs.append(quantized_bias_name) + + qlinear_conv_node = onnx.helper.make_node( + "QLinearConvTranspose", qlinear_conv_inputs, [qlinear_conv_output], qlinear_conv_name, **kwargs + ) + + # Add type information for the quantized node, as onnxruntime cannot infer the QLinearConvTranspose node currently + if node.output[0] in self.quantizer.value_infos: + op_shape = self.quantizer.value_infos[node.output[0]].type.tensor_type.shape.dim + op_shape = [it.dim_value for it in op_shape] + self.quantizer.model.graph().value_info.extend([onnx.helper.make_tensor_value_info(qlinear_conv_node.output[0], onnx.TensorProto.INT8, shape=op_shape)]) + + nodes.append(qlinear_conv_node) + + # Create an entry for this quantized value + q_output = QuantizedValue( + node.output[0], + qlinear_conv_output, + output_scale_name, + output_zp_name, + QuantizedValueType.Input, + ) + self.quantizer.quantized_value_map[node.output[0]] = q_output + + self.quantizer.new_nodes += nodes diff --git a/onnxruntime/python/tools/quantization/operators/direct_q8.py b/onnxruntime/python/tools/quantization/operators/direct_q8.py index c14532b96a..9e2838bd78 100644 --- a/onnxruntime/python/tools/quantization/operators/direct_q8.py +++ b/onnxruntime/python/tools/quantization/operators/direct_q8.py @@ -35,7 +35,10 @@ def quantize(self): else: # Force quantize those ops if possible, use exclude node list if this is not you want - if not self.quantizer.is_valid_quantize_weight(node.input[0]): + # TODO: this check seems overly restrictive, for now add an exception to allow MaxPool + # forced quantization but invesigate full removing the check + if not self.quantizer.is_valid_quantize_weight(node.input[0]) and\ + node.op_type != 'MaxPool': super().quantize() return diff --git a/onnxruntime/python/tools/quantization/operators/gavgpool.py b/onnxruntime/python/tools/quantization/operators/gavgpool.py index ceeead6846..6442e123a1 100644 --- a/onnxruntime/python/tools/quantization/operators/gavgpool.py +++ b/onnxruntime/python/tools/quantization/operators/gavgpool.py @@ -11,10 +11,15 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node assert node.op_type == "GlobalAveragePool" - - # If input to this node is not quantized then keep this node. + nodes = [] + # If input to this node is not quantized then force the quantization. if node.input[0] not in self.quantizer.quantized_value_map: - return super().quantize() + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_activation(node, [0]) quantized_input_value = self.quantizer.quantized_value_map[node.input[0]] @@ -59,4 +64,5 @@ def quantize(self): qnode_name, **kwargs, ) - self.quantizer.new_nodes += [qnode] + nodes.append(qnode) + self.quantizer.new_nodes += nodes diff --git a/onnxruntime/python/tools/quantization/operators/lstm.py b/onnxruntime/python/tools/quantization/operators/lstm.py index 7e91f9b76c..90a52cb528 100644 --- a/onnxruntime/python/tools/quantization/operators/lstm.py +++ b/onnxruntime/python/tools/quantization/operators/lstm.py @@ -47,10 +47,10 @@ def quantize(self): R.dims[0] = R_num_dir * R_4_hidden_size quant_input_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[1], onnx_proto.TensorProto.INT8, 0 + node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) quant_recurrent_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[2], onnx_proto.TensorProto.INT8, 0 + node.input[2], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) W_quant_weight = model.get_initializer(quant_input_weight_tuple[0]) # noqa: N806 diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index f87a9d8228..a03f6431fb 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -266,7 +266,13 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.") q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel( weight_name, - self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType, + # Quantization type is forced to be TensorProto.INT8. + # when the expected value would be (see below) + # self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType. + # QLinearConv expects to have a unique value for all channels. + # This code does not enforce that but it is necessarily the case when the + # quantization is symmetric (as for INT8). + onnx_proto.TensorProto.INT8, axis, keep_float_weight=self.add_qdq_pair_to_weight, ) diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index e8bcf9107c..3342f97f4e 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -5,6 +5,7 @@ from .operators.binary_op import QLinearBinaryOp from .operators.concat import QLinearConcat from .operators.conv import ConvInteger, QDQConv, QLinearConv +from .operators.convtranspose2d import QLinearConvTranspose from .operators.direct_q8 import Direct8BitOp, QDQDirect8BitOp from .operators.embed_layernorm import EmbedLayerNormalizationQuant from .operators.gather import GatherQuant, QDQGather @@ -40,6 +41,7 @@ QLinearOpsRegistry = { "ArgMax": QArgMax, "Conv": QLinearConv, + "ConvTranspose": QLinearConvTranspose, "Gemm": QLinearGemm, "MatMul": QLinearMatMul, "Add": QLinearBinaryOp, @@ -53,6 +55,7 @@ "Split": QSplit, "Pad": QPad, "Reshape": Direct8BitOp, + "Flatten": Direct8BitOp, "Squeeze": Direct8BitOp, "Unsqueeze": Direct8BitOp, "Resize": QResize, diff --git a/onnxruntime/python/tools/quantization/shape_inference.py b/onnxruntime/python/tools/quantization/shape_inference.py index eff3dc0bcd..b7d4726610 100644 --- a/onnxruntime/python/tools/quantization/shape_inference.py +++ b/onnxruntime/python/tools/quantization/shape_inference.py @@ -99,7 +99,10 @@ def quant_pre_process( sess_option = onnxruntime.SessionOptions() sess_option.optimized_model_filepath = opt_model_path sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC - _ = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"]) + sess = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"]) + # Close the session to avoid the cleanup error on Windows for temp folders + # https://github.com/microsoft/onnxruntime/issues/17627 + del sess except Exception: logger.error( "ONNX Runtime Model Optimization Failed! Consider rerun with option `--skip_optimization'." diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index f1ae93cfc1..a372446ede 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -3,6 +3,7 @@ # -*- coding: UTF-8 -*- import argparse +import copy import logging import numpy as np @@ -147,12 +148,15 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GatherElements": self._infer_GatherElements, "GatherND": self._infer_GatherND, "Identity": self._pass_on_shape_and_type, + "AllReduce": self._pass_on_shape_and_type, "If": self._infer_If, "Loop": self._infer_Loop, "MatMul": self._infer_MatMul, "MatMulInteger16": self._infer_MatMulInteger, "MaxPool": self._infer_Pool, "Max": self._infer_symbolic_compute_ops, + "MemcpyFromHost": self._pass_on_shape_and_type, + "MemcpyToHost": self._pass_on_shape_and_type, "Min": self._infer_symbolic_compute_ops, "Mul": self._infer_symbolic_compute_ops, "NonMaxSuppression": self._infer_NonMaxSuppression, @@ -199,6 +203,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "Gelu": self._infer_Gelu, "GemmFastGelu": self._infer_GemmFastGelu, "GroupNorm": self._infer_GroupNorm, + "SkipGroupNorm": self._infer_SkipGroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, "MultiHeadAttention": self._infer_MultiHeadAttention, @@ -206,12 +211,27 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "PackedAttention": self._infer_PackedAttention, "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, "PythonOp": self._infer_PythonOp, + "QuickGelu": self._infer_FastGelu, "RelativePositionBias": self._infer_RelativePositionBias, "RemovePadding": self._infer_RemovePadding, "RestorePadding": self._infer_RestorePadding, + "RotaryEmbedding": self._infer_RotaryEmbedding, "SimplifiedLayerNormalization": self._infer_LayerNormalization, "SkipLayerNormalization": self._infer_SkipLayerNormalization, "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, + # QLinear ops + "QLinearConcat": self._infer_qlinear_concat, + "QGemm": self._infer_qgemm, + "QLinearAdd": self._infer_qlinear_binary_op, + "QLinearMul": self._infer_qlinear_binary_op, + "QLinearLeakyRelu": self._infer_qlinear_unary_op, + "QLinearConvTranspose": self._infer_qlinear_binary_op, + "QLinearSigmoid": self._infer_qlinear_unary_op, + "QLinearSoftmax": self._infer_qlinear_unary_op, + "QLinearGlobalAveragePool": self._infer_qlinear_unary_op, + "QLinearAveragePool": self._infer_qlinear_unary_op, + # Quadric custom operators + "QuadricCustomOp": self._infer_custom_op, } self.aten_op_dispatcher_ = { "embedding": self._infer_Gather, @@ -230,7 +250,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "upsample_nearest1d": self._infer_aten_upsample, "upsample_nearest2d": self._infer_aten_upsample, "upsample_nearest3d": self._infer_aten_upsample, - "upsample_bilinear2d": self._infer_aten_upsample, } self.run_ = True self.suggested_merge_ = {} @@ -440,6 +459,7 @@ def _onnx_infer_single_node(self, node): "If", "Loop", "Scan", + "QuadricCustomOp", "SplitToSequence", "ZipMap", # contrib ops "Attention", @@ -463,6 +483,8 @@ def _onnx_infer_single_node(self, node): "BiasSplitGelu", "BiasAdd", "NhwcConv", + "QuickGelu", + "RotaryEmbedding", ] if not skip_infer: @@ -897,7 +919,85 @@ def _infer_Concat(self, node): # noqa: N802 ) ) - def _infer_ConcatFromSequence(self, node): # noqa: N802 + def _filter_node_inputs(self, node, inp_list): + # Create a copy of the node, remove the additional arguments of the quantized version + # and call the FP32 version of the inference + new_node = copy.deepcopy(node) + for idx in range(len(new_node.input) - 1, -1, -1): + if idx not in inp_list: + del new_node.input[idx] + return new_node + + def _qlinear_onnx_shape_infer(self, node, prequant_input_idx): + # Remove the quantization specific input and + # change the node type to match the unquantized + # node, then use ONNX to infer the output type + new_node = self._filter_node_inputs(node, prequant_input_idx) + new_node.op_type = new_node.op_type.replace("QLinear", "") + new_node.domain = "" + self._onnx_infer_single_node(new_node) + + def _infer_qlinear_unary_op(self, node): + # For qlinear unary operators the input order is + # [inp, inp_scale, inp_zp, out_scale, out_zp] and + # we want to preserve [inp] for shape inference + # https://github.com/quadric-io/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearSigmoid + prequant_input_idx = [0] + self._qlinear_onnx_shape_infer(node, prequant_input_idx) + + def _infer_qlinear_binary_op(self, node): + # For qlinear binary operators the input order is + # [inp_0, inp_0_scale, inp_0_zp, inp_1, inp_1_scale, inp_1_zp, out_scale, out_zp] + # and we want to preserve [inp_0, inp_1] + # This also applies to operators where the shape is determined by the input and weight shapes, such as + # QLinearConv and QLinearConvTranspose + # https://github.com/quadric-io/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAdd + prequant_input_idx = [0, 3] + self._qlinear_onnx_shape_infer(node, prequant_input_idx) + + def _infer_qlinear_concat(self, node): + # The inputs for QLinearConcat are in the format + # [y_scale, y_zp, inp_0, inp_0_scale, inp_0_zp, inp_1, inp_1_sc...] + # https://github.com/quadric-io/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearConcat + # After removing the quantization params we should be left with + # [inp_0, inp_1, ...] + num_prequant_inputs = (len(node.input) - 2) // 3 + prequant_input_idx = [idx * 3 + 2 for idx in range(num_prequant_inputs)] + self._qlinear_onnx_shape_infer(node, prequant_input_idx) + + def _infer_qgemm(self, node): + # QGemm has a different naming convention compared to the rest of the + # QLinearOps, treat it separately. + # The inputs for QLinearConcat are in the format + # [A, a_scale, a_zp, B, b_scale, b_zp, c, y_scale, y_zp] + # https://github.com/quadric-io/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QGemm + # After removing the quantization params we should be left with [A, B] + prequant_input_idx = [0, 3] + new_node = self._filter_node_inputs(node, prequant_input_idx) + new_node.op_type = "Gemm" + new_node.domain = "" + self._onnx_infer_single_node(new_node) + + def _infer_custom_op(self, node): + # For the CCL custom operators the shape and dtype of the output are present in + # the attributes and can be used to directly create the value info + attr_map = {n.name: n for n in list(node.attribute)} + assert "shape" in attr_map and "elem_type" in attr_map, "Custom op output type not found" + if len(node.output) > 1: + for i, out in enumerate(node.output): + vi = self.known_vi_[out] + vi.CopyFrom( + helper.make_tensor_value_info( + out, + attr_map["elem_type"].ints[i], + attr_map["shape"].tensors[i].int32_data, + ) + ) + else: + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], attr_map["elem_type"].i, attr_map["shape"].ints)) + + def _infer_ConcatFromSequence(self, node): seq_shape = self._get_shape(node, 0) new_axis = 1 if get_attribute(node, "new_axis") else 0 axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis) @@ -2308,6 +2408,9 @@ def _infer_FastGelu(self, node): # noqa: N802 def _infer_Gelu(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_QuickGelu(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + def _infer_GemmFastGelu(self, node): # noqa: N802 self._compute_matmul_shape(node) @@ -2366,6 +2469,11 @@ def _infer_SkipLayerNormalization(self, node): # noqa: N802 def _infer_GroupNorm(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_SkipGroupNorm(self, node): # noqa: N802 + self._propagate_shape_and_type(node, 0, 0) + if len(node.output) > 1: + self._propagate_shape_and_type(node, 0, 1) + def _infer_BiasSplitGelu(self, node): # noqa: N802 input_shape = self._get_shape(node, 0) bias_shape = self._get_shape(node, 1) @@ -2379,6 +2487,19 @@ def _infer_BiasSplitGelu(self, node): # noqa: N802 def _infer_BiasAdd(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_RotaryEmbedding(self, node): # noqa: N802 + if len(node.output) == 1: + self._propagate_shape_and_type(node) + elif len(node.output) == 2: + # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` + self._propagate_shape_and_type(node, input_index=1, output_index=0) + self._propagate_shape_and_type(node, input_index=0, output_index=1) # true output + elif len(node.output) == 3: + # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` + self._propagate_shape_and_type(node, input_index=1, output_index=0) + self._propagate_shape_and_type(node, input_index=1, output_index=1) + self._propagate_shape_and_type(node, input_index=0, output_index=2) # true output + def _infer_PythonOp(self, node): # noqa: N802 output_tensor_types = get_attribute(node, "output_tensor_types") assert output_tensor_types @@ -2493,6 +2614,10 @@ def get_prereq(node): get_attribute(node, "then_branch"), get_attribute(node, "else_branch"), ] + elif node.op_type == "QuadricCustomOp": + # Should have a subgraph, but allow for cases where it's not there + subgraph = get_attribute(node, "sub_graph") + subgraphs = [subgraph] if subgraph else [] elif node.op_type in ["Loop", "Scan"]: subgraphs = [get_attribute(node, "body")] for g in subgraphs: @@ -2584,12 +2709,19 @@ def get_prereq(node): self._check_merged_dims(in_dims, allow_broadcast=True) for i_o in range(len(node.output)): - # Special case: We do not care about the training related - # outputs of SkipLayerNormalization + # Special cases: + # 1) We do not care about the training related outputs of SkipLayerNormalization + # 2) We do not care about the extraneous constant outputs in RotaryEmbedding because + # the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding + # contrib op if ( node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization" ) and i_o in [1, 2]: continue + if node.op_type == "RotaryEmbedding" and len(node.output) > 1: + # Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs + # generated by `export_modules_as_functions` + continue vi = self.known_vi_[node.output[i_o]] out_type = vi.type @@ -2751,13 +2883,13 @@ def get_prereq(node): if i in self.known_vi_: logger.debug(self.known_vi_[i]) else: - logger.debug(f"not in knwon_vi_ for {i}") + logger.debug(f"not in known_vi_ for {i}") logger.debug("node outputs:") for o in node.output: if o in self.known_vi_: logger.debug(self.known_vi_[o]) else: - logger.debug(f"not in knwon_vi_ for {o}") + logger.debug(f"not in known_vi_ for {o}") if self.auto_merge_ and not out_type_undefined: logger.debug("Merging: " + str(self.suggested_merge_)) return False diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 5fa64d1bc0..b6f7a44450 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -8,7 +8,10 @@ import logging import os import random +import sys +import time import timeit +from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from datetime import datetime from enum import Enum @@ -30,6 +33,7 @@ class Precision(Enum): FLOAT32 = "fp32" FLOAT16 = "fp16" INT8 = "int8" + INT4 = "int4" def __str__(self): return self.value @@ -170,7 +174,7 @@ def prepare_environment(cache_dir, output_dir, use_gpu, provider=None): logger.info(f"PyTorch Version:{torch.__version__}") logger.info(f"Transformers Version:{transformers.__version__}") - logger.info(f"Onnxruntime Version:{onnxruntime.__version__}") + logger.info(f"OnnxRuntime Version:{onnxruntime.__version__}") # Support three major versions of PyTorch and OnnxRuntime, and up to 9 months of transformers. assert version.parse(torch.__version__) >= version.parse("1.10.0") @@ -439,76 +443,141 @@ def get_gpu_info() -> Optional[List[Dict[str, Any]]]: return None -def measure_memory(is_gpu, func): - class MemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring +class MemoryMonitor(ABC): + def __init__(self, keep_measuring=True): + self.keep_measuring = keep_measuring - def measure_cpu_usage(self): - import psutil + def measure_cpu_usage(self): + import psutil - max_usage = 0 + max_usage = 0 + while True: + max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2) + sleep(0.005) # 5ms + if not self.keep_measuring: + break + return max_usage + + @abstractmethod + def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: + raise NotImplementedError() + + +class CudaMemoryMonitor(MemoryMonitor): + def __init__(self, keep_measuring=True): + super().__init__(keep_measuring) + + def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: + from py3nvml.py3nvml import ( + NVMLError, + nvmlDeviceGetCount, + nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, + nvmlDeviceGetName, + nvmlInit, + nvmlShutdown, + ) + + max_gpu_usage = [] + gpu_name = [] + try: + nvmlInit() + device_count = nvmlDeviceGetCount() + if not isinstance(device_count, int): + logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}") + return None + + max_gpu_usage = [0 for i in range(device_count)] + gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)] while True: - max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2) + for i in range(device_count): + info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)) + if isinstance(info, str): + logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}") + return None + max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2) sleep(0.005) # 5ms if not self.keep_measuring: break - return max_usage - - def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: - from py3nvml.py3nvml import ( - NVMLError, - nvmlDeviceGetCount, - nvmlDeviceGetHandleByIndex, - nvmlDeviceGetMemoryInfo, - nvmlDeviceGetName, - nvmlInit, - nvmlShutdown, - ) + nvmlShutdown() + return [ + { + "device_id": i, + "name": gpu_name[i], + "max_used_MB": max_gpu_usage[i], + } + for i in range(device_count) + ] + except NVMLError as error: + logger.error("Error fetching GPU information using nvml: %s", error) + return None - max_gpu_usage = [] - gpu_name = [] - try: - nvmlInit() - device_count = nvmlDeviceGetCount() - if not isinstance(device_count, int): - logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}") - return None - - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)] - while True: - for i in range(device_count): - info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)) - if isinstance(info, str): - logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}") - return None - max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2) - sleep(0.005) # 5ms - if not self.keep_measuring: - break - nvmlShutdown() - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] - except NVMLError as error: - logger.error("Error fetching GPU information using nvml: %s", error) - return None - monitor = MemoryMonitor(False) +class RocmMemoryMonitor(MemoryMonitor): + def __init__(self, keep_measuring=True): + super().__init__(keep_measuring) + rocm_smi_path = "/opt/rocm/libexec/rocm_smi" + if os.path.exists(rocm_smi_path): + if rocm_smi_path not in sys.path: + sys.path.append(rocm_smi_path) + try: + import rocm_smi + + self.rocm_smi = rocm_smi + self.rocm_smi.initializeRsmi() + except ImportError: + self.rocm_smi = None + + def get_used_memory(self, dev): + if self.rocm_smi is None: + return -1 + return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024 + + def measure_gpu_usage(self): + if self.rocm_smi is None: + return None + + device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0 + max_gpu_usage = [0 for i in range(device_count)] + gpu_name = [f"GPU{i}" for i in range(device_count)] + while True: + for i in range(device_count): + max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i)) + time.sleep(0.005) # 5ms + if not self.keep_measuring: + break + return [ + { + "device_id": i, + "name": gpu_name[i], + "max_used_MB": max_gpu_usage[i], + } + for i in range(device_count) + ] + + +def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None): + memory_monitor_type = None + if monitor_type == "rocm": + memory_monitor_type = RocmMemoryMonitor + else: + memory_monitor_type = CudaMemoryMonitor + + monitor = memory_monitor_type(False) if is_gpu: - memory_before_test = monitor.measure_gpu_usage() + if start_memory is not None: + memory_before_test = start_memory + else: + memory_before_test = monitor.measure_gpu_usage() if memory_before_test is None: return None + if func is None: + return memory_before_test + with ThreadPoolExecutor() as executor: - monitor = MemoryMonitor() + monitor = memory_monitor_type() mem_thread = executor.submit(monitor.measure_gpu_usage) try: fn_thread = executor.submit(func) @@ -533,10 +602,16 @@ def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]: return None # CPU memory - memory_before_test = monitor.measure_cpu_usage() + if start_memory is not None: + memory_before_test = start_memory + else: + memory_before_test = monitor.measure_cpu_usage() + + if func is None: + return memory_before_test with ThreadPoolExecutor() as executor: - monitor = MemoryMonitor() + monitor = memory_monitor_type() mem_thread = executor.submit(monitor.measure_cpu_usage) try: fn_thread = executor.submit(func) diff --git a/onnxruntime/python/tools/transformers/compare_bert_results.py b/onnxruntime/python/tools/transformers/compare_bert_results.py index 4cb9585962..61e4c97c75 100644 --- a/onnxruntime/python/tools/transformers/compare_bert_results.py +++ b/onnxruntime/python/tools/transformers/compare_bert_results.py @@ -33,18 +33,20 @@ def run_model(model_path, all_inputs, use_gpu, disable_optimization): return results, latency_list, output_names -def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4): +def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3): # Validate the output of baseline and treatment, to make sure the results are similar. diff_count = 0 - max_rel_diff = 0 max_abs_diff = 0 for test_case_id, results in enumerate(baseline_results): case_passed = True for i in range(len(results)): treatment_output = treatment_results[test_case_id][i] - rel_diff = np.amax(np.abs((treatment_output - results[i]) / results[i])) abs_diff = np.amax(np.abs(treatment_output - results[i])) - max_rel_diff = max(max_rel_diff, rel_diff) + if verbose and abs_diff > atol: + print("abs_diff", abs_diff) + print("treatment", treatment_output) + print("baseline", results[i]) + max_abs_diff = max(max_abs_diff, abs_diff) if not np.allclose(results[i].tolist(), treatment_output.tolist(), rtol=rtol, atol=atol): if case_passed: @@ -54,7 +56,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4): if verbose: print(f"case {test_case_id} output {i}") print(f"baseline={results[i].tolist()}\ntreatment={treatment_output}") - print(f"rel_diff={rel_diff} abs_diff={abs_diff}") + print(f"abs_diff={abs_diff}") if diff_count == 0: print( @@ -70,8 +72,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4): ) print(f"maximum absolute difference={max_abs_diff}") - - print(f"maximum relative difference={max_rel_diff}") + return max_abs_diff, case_passed def run_test( @@ -133,7 +134,7 @@ def run_test( print(f"treatment average latency: {statistics.mean(treatment_latency) * 1000} ms") # Validate the output of baseline and treatment, to make sure the results are similar. - compare(baseline_results, treatment_results, verbose, rtol, atol) + return compare(baseline_results, treatment_results, verbose, rtol, atol) def parse_arguments(): diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 73561d312e..b59af41c49 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -883,7 +883,8 @@ def remove_shared_initializers( graph2: GraphProto, shared_prefix: str = "shared_", min_elements: int = 1024, - require_raw_data: bool = False, + signature_cache1: Optional[dict] = None, + signature_cache2: Optional[dict] = None, ): """Remove initializers with same value from two graphs. @@ -892,7 +893,8 @@ def remove_shared_initializers( graph2 (GraphProto): the second graph to process shared_prefix (str): add prefix to the shared initializers among two graphs min_elements (int, optional): minimal number of elements for initializers to be considered. Defaults to 1024. - require_raw_data (bool, optional): Only remove tensors with raw_data field to speed up method + signature_cache1 (dict): Optional dictionary to store data signatures of tensors in graph1 in order to speed up comparison + signature_cache2 (dict): Optional dictionary to store data signatures of tensors in graph2 in order to speed up comparison """ mapping_initializers_1 = {} @@ -909,7 +911,7 @@ def remove_shared_initializers( if not (initializer2.dims and sum(initializer2.dims) >= min_elements): continue - if OnnxModel.has_same_value(initializer1, initializer2, require_raw_data=True): + if OnnxModel.has_same_value(initializer1, initializer2, signature_cache1, signature_cache2): mapping_initializers_1[initializer1.name] = shared_prefix + initializer2.name shared_initializers_1.append(initializer1) @@ -982,14 +984,21 @@ def remove_shared_initializers( return shared_initializers_2 -def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto, require_raw_data: bool = False): +def get_shared_initializers(encoder_model: ModelProto, decoder_model: ModelProto): encoder = OnnxModel(encoder_model) decoder = OnnxModel(decoder_model) encoder.add_prefix_to_names("e_") decoder.add_prefix_to_names("d_") - encoder.remove_duplicated_initializer(require_raw_data) - decoder.remove_duplicated_initializer(require_raw_data) - initializers = remove_shared_initializers(decoder.model.graph, encoder.model.graph, "s_", require_raw_data) + signature_cache1, signature_cache2 = {}, {} + encoder.remove_duplicated_initializer(signature_cache1) + decoder.remove_duplicated_initializer(signature_cache2) + initializers = remove_shared_initializers( + decoder.model.graph, + encoder.model.graph, + shared_prefix="s_", + signature_cache1=signature_cache1, + signature_cache2=signature_cache2, + ) return initializers @@ -1263,7 +1272,139 @@ def find_past_seq_len_usage(subg: GraphProto): return tensor_names_to_rename, nodes_to_remove -def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphProto): +def replace_mha_with_gqa(model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1): + # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes + # + # attention_mask + # / \ + # ReduceSum Shape + # | | + # Sub Gather + # | | + # seqlens_k total_sequence_length + # | | + # Cast to int32 Cast to int32 + + model.add_initializer( + onnx.helper.make_tensor( + name="one", + data_type=TensorProto.INT64, + dims=[1], + vals=[1], + ) + ) + reduce_sum_node = onnx.helper.make_node( + "ReduceSum", + inputs=[attn_mask, "one"], + outputs=[attn_mask + "_row_sums"], + name=model.create_node_name("ReduceSum"), + ) + sub_node = onnx.helper.make_node( + "Sub", + inputs=[attn_mask + "_row_sums", "one"], + outputs=["seqlens_k_int64"], + name=model.create_node_name("Sub"), + ) + seqlen_k_cast_node = onnx.helper.make_node( + "Cast", + inputs=["seqlens_k_int64"], + outputs=["seqlens_k"], + name=model.create_node_name("Cast"), + to=TensorProto.INT32, + ) + shape_node = onnx.helper.make_node( + "Shape", + inputs=[attn_mask], + outputs=[attn_mask + "_shape"], + name=model.create_node_name("Shape"), + ) + gather_node = onnx.helper.make_node( + "Gather", + inputs=[attn_mask + "_shape", "one"], + outputs=["total_seq_len_int64"], + name=model.create_node_name("Gather"), + axis=0, + ) + total_seqlen_cast_node = onnx.helper.make_node( + "Cast", + inputs=["total_seq_len_int64"], + outputs=["total_seq_len"], + name=model.create_node_name("Cast"), + to=TensorProto.INT32, + ) + model.model.graph.node.extend( + [reduce_sum_node, sub_node, seqlen_k_cast_node, shape_node, gather_node, total_seqlen_cast_node] + ) + + # Replace MultiHeadAttention with GroupQueryAttention + mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node)) + for node in mha_nodes: + num_heads_mha = 0 + for att in node.attribute: + if att.name == "num_heads": + num_heads_mha = att.i + gqa_node = onnx.helper.make_node( + "GroupQueryAttention", + inputs=[ + node.input[0], # query + node.input[1], # key + node.input[2], # value + node.input[6], # past_key + node.input[7], # past_value + "seqlens_k", # seqlens_k (for attention_mask) + "total_seq_len", # total_seq_len (for attention_mask) + ], + outputs=node.output, + name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"), + domain="com.microsoft", + num_heads=num_heads_mha // world_size, + kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size, + ) + model.model.graph.node.remove(node) + model.model.graph.node.extend([gqa_node]) + return model + + +def update_decoder_subgraph_output_cross_attention(subg: GraphProto): + input_self_past_0 = 1 + # w/wo attention mask, w/wo hidden_state + graph_input_names = [gi.name for gi in subg.input] + while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"): + input_self_past_0 += 1 + output_self_present_0 = 1 + + num_layers = (len(subg.output) - output_self_present_0) // 2 + input_cross_past_0 = 2 * num_layers + input_self_past_0 + past_key_cross_inputs = {subg.input[layer * 2 + input_cross_past_0].name: layer for layer in range(num_layers)} + print(f" --past_key_cross_inputs={past_key_cross_inputs}") + + input_past_key_cross_0_shape = shape_of(subg.input[input_cross_past_0]) + print(f"past_key_cross_0_shape is {input_past_key_cross_0_shape}") + batch_size_dim = input_past_key_cross_0_shape[0] + num_heads_dim = input_past_key_cross_0_shape[1] + cross_seq_len_dim = input_past_key_cross_0_shape[2] + + num_layer_output_qk = 0 + for node in subg.node: + if (node.op_type == "DecoderMaskedMultiHeadAttention") and (node.input[1] in past_key_cross_inputs): + print(f" -- add cross QK output from: node: {node.name} with output: {node.output}") + num_layer_output_qk += 1 + layer = past_key_cross_inputs[node.input[1]] + cross_attention_out_name = f"output_cross_qk_{layer}" + appended_names = [""] * (3 - len(node.output)) + appended_names.append(cross_attention_out_name) + node.output.extend(appended_names) + node.attribute.extend([onnx.helper.make_attribute("output_qk", 1)]) + + cross_attention = onnx.helper.make_tensor_value_info( + cross_attention_out_name, TensorProto.FLOAT, [batch_size_dim, num_heads_dim, 1, cross_seq_len_dim] + ) + subg.output.extend([cross_attention]) + if num_layer_output_qk != num_layers: + raise ValueError(f"Did not add cross QK for all layers{num_layers} vs {num_layer_output_qk}") + + +def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: ModelProto): input_self_past_0 = 1 # w/wo attention mask, w/wo hidden_state graph_input_names = [gi.name for gi in subg.input] diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index 02a260b784..95e7437493 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -20,8 +20,7 @@ import numpy as np import onnx -from onnx import helper, numpy_helper -from onnx import onnx_pb as onnx_proto +from onnx import AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, helper, numpy_helper from onnx.shape_inference import infer_shapes, infer_shapes_path from packaging import version @@ -87,11 +86,11 @@ def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finit TensorProto: the converted tensor. """ - if not isinstance(tensor, onnx_proto.TensorProto): + if not isinstance(tensor, TensorProto): raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}") - if tensor.data_type == onnx_proto.TensorProto.FLOAT: - tensor.data_type = onnx_proto.TensorProto.FLOAT16 + if tensor.data_type == TensorProto.FLOAT: + tensor.data_type = TensorProto.FLOAT16 # convert float_data (float type) to float16 and write to int32_data if tensor.float_data: float16_data = convert_np_to_float16(np.array(tensor.float_data), min_positive_val, max_finite_val) @@ -146,18 +145,19 @@ def make_value_info_from_tensor(tensor): # Some operators has data type fixed as float for some inputs. Key is op_type, value is list of input indices -ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2]} +# Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter could overwrite this. +ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]} class InitializerTracker: """Class for keeping track of initializer.""" - def __init__(self, initializer: onnx_proto.TensorProto): + def __init__(self, initializer: TensorProto): self.initializer = initializer self.fp32_nodes = [] self.fp16_nodes = [] - def add_node(self, node: onnx_proto.NodeProto, is_node_blocked): + def add_node(self, node: NodeProto, is_node_blocked): if is_node_blocked: self.fp32_nodes.append(node) else: @@ -219,7 +219,7 @@ def convert_float_to_float16( else: model = onnx.load(model_path) - if not isinstance(model, onnx_proto.ModelProto): + if not isinstance(model, ModelProto): raise ValueError(f"Expected an ONNX ModelProto but got {type(model)}") func_infer_shape = None @@ -259,8 +259,8 @@ def convert_float_to_float16( graph_io_to_skip = set() io_casts = set() - fp32_inputs = [n.name for n in model.graph.input if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT] - fp32_outputs = [n.name for n in model.graph.output if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT] + fp32_inputs = [n.name for n in model.graph.input if n.type.tensor_type.elem_type == TensorProto.FLOAT] + fp32_outputs = [n.name for n in model.graph.output if n.type.tensor_type.elem_type == TensorProto.FLOAT] if isinstance(keep_io_types, list): fp32_inputs = [n for n in fp32_inputs if n in keep_io_types] fp32_outputs = [n for n in fp32_outputs if n in keep_io_types] @@ -278,9 +278,9 @@ def convert_float_to_float16( new_value_info = model.graph.value_info.add() new_value_info.CopyFrom(n) new_value_info.name = output_name - new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 + new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT16 # add Cast node (from tensor(float) to tensor(float16) after graph input - new_node = [helper.make_node("Cast", [n.name], [output_name], to=10, name=node_name)] + new_node = [helper.make_node("Cast", [n.name], [output_name], to=TensorProto.FLOAT16, name=node_name)] model.graph.node.extend(new_node) value_info_list.append(new_value_info) io_casts.add(node_name) @@ -296,7 +296,7 @@ def convert_float_to_float16( new_value_info = model.graph.value_info.add() new_value_info.CopyFrom(n) new_value_info.name = input_name - new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 + new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT16 new_node = [helper.make_node("Cast", [input_name], [n.name], to=1, name=node_name)] model.graph.node.extend(new_node) value_info_list.append(new_value_info) @@ -307,12 +307,12 @@ def convert_float_to_float16( next_level = [] for q in queue: # if q is model, push q.graph (GraphProto) - if isinstance(q, onnx_proto.ModelProto): + if isinstance(q, ModelProto): next_level.append(q.graph) # if q is model.graph, push q.node.attribute (AttributeProto) - if isinstance(q, onnx_proto.GraphProto): + if isinstance(q, GraphProto): for n in q.initializer: # TensorProto type - if n.data_type == onnx_proto.TensorProto.FLOAT: + if n.data_type == TensorProto.FLOAT: assert n.name not in fp32_initializers fp32_initializers[n.name] = InitializerTracker(n) @@ -343,10 +343,32 @@ def convert_float_to_float16( else: if n.op_type == "Cast": for attr in n.attribute: - if attr.name == "to" and attr.i == 1: - attr.i = 10 + if attr.name == "to" and attr.i == TensorProto.FLOAT: + attr.i = TensorProto.FLOAT16 break + if n.op_type in [ + "EyeLike", + "Multinomial", + "RandomNormal", + "RandomNormalLike", + "RandomUniform", + "RandomUniformLike", + "SequenceEmpty", + "Bernoulli", + ]: + has_dtype = False + for attr in n.attribute: + if attr.name == "dtype": + has_dtype = True + if attr.i == TensorProto.FLOAT: + attr.i = TensorProto.FLOAT16 + + # The dtype attribute is optional and default is FLOAT in the following operators + # so we need add dtype attribute to specify the data type float16 + if (n.op_type in ["RandomNormal", "RandomUniform", "SequenceEmpty"]) and not has_dtype: + n.attribute.extend([helper.make_attribute("dtype", TensorProto.FLOAT16)]) + # For Resize/GroupNorm, attribute data type cannot be changed if n.op_type not in ALWAYS_FLOAT_INPUTS or n.op_type in force_fp16_inputs_dict: for attr in n.attribute: @@ -356,7 +378,7 @@ def convert_float_to_float16( # if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto) # and process node.attribute.t and node.attribute.tensors (TensorProto) - if isinstance(q, onnx_proto.AttributeProto): + if isinstance(q, AttributeProto): next_level.append(q.g) for n in q.graphs: next_level.append(n) # noqa: PERF402 @@ -364,19 +386,19 @@ def convert_float_to_float16( for n in q.tensors: n = convert_tensor_float_to_float16(n, min_positive_val, max_finite_val) # noqa: PLW2901 # if q is graph, process input, output and value_info (ValueInfoProto) - if isinstance(q, onnx_proto.GraphProto): + if isinstance(q, GraphProto): # Note that float initializers tracked by fp32_initializers will be processed later. # for all ValueInfoProto with tensor(float) type in input, output and value_info, convert them to # tensor(float16) except map and seq(map). And save them in value_info_list for further processing for n in itertools.chain(q.input, q.output, q.value_info): - if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: + if n.type.tensor_type.elem_type == TensorProto.FLOAT: if n.name not in graph_io_to_skip: - n.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 + n.type.tensor_type.elem_type = TensorProto.FLOAT16 value_info_list.append(n) if n.type.HasField("sequence_type"): - if n.type.sequence_type.elem_type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: + if n.type.sequence_type.elem_type.tensor_type.elem_type == TensorProto.FLOAT: if n.name not in graph_io_to_skip: - n.type.sequence_type.elem_type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 + n.type.sequence_type.elem_type.tensor_type.elem_type = TensorProto.FLOAT16 value_info_list.append(n) queue = next_level @@ -405,7 +427,7 @@ def convert_float_to_float16( new_value_info.CopyFrom(value_info) output_name = node.name + "_input_cast_" + str(i) new_value_info.name = output_name - new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT + new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT # add Cast node (from tensor(float16) to tensor(float) before current node node_name = node.name + "_input_cast" + str(i) new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)] @@ -428,7 +450,7 @@ def convert_float_to_float16( new_value_info.CopyFrom(value_info) output_name = node.name + "_input_cast_" + str(i) new_value_info.name = output_name - new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT + new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT # add Cast node (from tensor(float16) to tensor(float) before current node node_name = node.name + "_input_cast" + str(i) new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)] @@ -447,7 +469,7 @@ def convert_float_to_float16( new_value_info.CopyFrom(value_info) input_name = node.name + "_output_cast_" + str(i) new_value_info.name = input_name - new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT + new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT # add Cast node (from tensor(float) to tensor(float16) after current node node_name = node.name + "_output_cast" + str(i) new_node = [helper.make_node("Cast", [input_name], [output], to=10, name=node_name)] @@ -460,9 +482,9 @@ def convert_float_to_float16( def float_to_float16_max_diff(tensor, min_positive_val=5.96e-08, max_finite_val=65504.0): """Measure the maximum absolute difference after converting a float tensor to float16.""" - if not isinstance(tensor, onnx_proto.TensorProto): + if not isinstance(tensor, TensorProto): raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}") - if tensor.data_type != onnx_proto.TensorProto.FLOAT: + if tensor.data_type != TensorProto.FLOAT: raise ValueError("Expected tensor data type is float.") float32_data = None diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 31496c5052..edaf78edb2 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -84,6 +84,7 @@ def process_mask(self, input: str) -> str: data_type=TensorProto.INT64, dims=[1], vals=[1], + raw=False, ) ) mask_index_node = helper.make_node( @@ -110,7 +111,7 @@ def __init__( model: OnnxModel, hidden_size: int, num_heads: int, - attention_mask: AttentionMask, + attention_mask: Optional[AttentionMask] = None, use_multi_head_attention: bool = False, disable_multi_head_attention_bias: bool = False, search_op_types: List[str] = ["SkipLayerNormalization", "LayerNormalization"], # noqa: B006 @@ -119,7 +120,7 @@ def __init__( super().__init__(model, attention_op_name, search_op_types) self.hidden_size = hidden_size self.num_heads = num_heads - self.attention_mask = attention_mask + self.attention_mask = attention_mask if attention_mask else AttentionMask(model) self.use_multi_head_attention = use_multi_head_attention self.disable_multi_head_attention_bias = disable_multi_head_attention_bias self.mask_filter_value = None @@ -203,7 +204,7 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int] def get_add_qk_str(self, add_qk: NodeProto): shape_infer = self.model.infer_runtime_shape(update=True) if shape_infer is None: - return + return None input_0_shape = shape_infer.get_edge_shape(add_qk.input[0]) input_1_shape = shape_infer.get_edge_shape(add_qk.input[1]) @@ -218,6 +219,31 @@ def get_add_qk_str(self, add_qk: NodeProto): return add_qk.input[1] + def reshape_add_qk(self, add_qk: str): + # Convert 4D mask from (B,1,S,T) to (B,N,S,T) + # B = batch size, N = num heads, S = source sequence length, T = target sequence length + mask_output_name = add_qk + "_mask" + + # Check if concat node for (B,1,S,T) --> (B,N,S,T) already exists + concat_node = list(filter(lambda node: node.output[0] == mask_output_name, self.nodes_to_add)) + if len(concat_node) == 1: + return mask_output_name + + assert len(concat_node) == 0 + concat_node_name = self.model.create_node_name("Concat") + concat_add_qk_fp32 = helper.make_node( + "Concat", + inputs=[add_qk for _ in range(self.num_heads)], + outputs=[mask_output_name], + name=concat_node_name, + axis=1, + ) + # Add new node to graph + self.nodes_to_add.append(concat_add_qk_fp32) + self.node_name_to_graph_name[concat_node_name] = self.this_graph_name + + return mask_output_name + def concat_kv(self, past_k: str, past_v: str) -> str: """Concatenate past_k and past_v inputs to create past_kv input. @@ -428,19 +454,12 @@ def create_combined_qkv_bias( qkv_bias_dim = 3 * np.prod(qb.shape) bias_name = name_prefix + "_qkv_bias" - bias = helper.make_tensor( + self.add_initializer( name=bias_name, - data_type=TensorProto.FLOAT, + data_type=q_bias.data_type, dims=[qkv_bias_dim], - vals=qkv_bias.flatten().tolist(), + vals=qkv_bias, ) - - # Convert bias to FP16 if model is using FP16 - if q_bias.data_type == 10: - bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name)) - - self.model.add_initializer(bias, self.this_graph_name) - return bias_name def create_packed_qkv_matmul_node( @@ -488,13 +507,13 @@ def create_packed_qkv_matmul_node( qkv_weight = np.stack((qw, kw, vw), axis=1).reshape((d, 3 * d)) qkv_weight_name = matmul_node_name + "_qkv_weight" - weight = helper.make_tensor( + + self.add_initializer( name=qkv_weight_name, - data_type=TensorProto.FLOAT, + data_type=q_weight.data_type, dims=[qkv_weight.shape[0], qkv_weight.shape[1]], - vals=qkv_weight.flatten().tolist(), + vals=qkv_weight, ) - self.model.add_initializer(weight, self.this_graph_name) # Created packed QKV MatMul with output (B, S, 3*D) # Output is of the form: @@ -519,23 +538,15 @@ def create_packed_qkv_matmul_node( # Create Slice nodes to access Q, K, V q_slice_name = matmul_node_name + "_q_start_index" - q_start_tensor = helper.make_tensor(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0]) + self.add_initializer(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0], raw=False) k_slice_name = matmul_node_name + "_k_start_index" - k_start_tensor = helper.make_tensor(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[d]) + self.add_initializer(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[d], raw=False) v_slice_name = matmul_node_name + "_v_start_index" - v_start_tensor = helper.make_tensor(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * d]) + self.add_initializer(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * d], raw=False) end_of_qkv_name = matmul_node_name + "_end_of_qkv_index" - end_of_qkv_tensor = helper.make_tensor( - name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * d] - ) + self.add_initializer(name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * d], raw=False) qkv_last_axis_name = matmul_node_name + "_qkv_last_axis" - qkv_axis_tensor = helper.make_tensor(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1]) - - self.model.add_initializer(q_start_tensor, self.this_graph_name) - self.model.add_initializer(k_start_tensor, self.this_graph_name) - self.model.add_initializer(v_start_tensor, self.this_graph_name) - self.model.add_initializer(end_of_qkv_tensor, self.this_graph_name) - self.model.add_initializer(qkv_axis_tensor, self.this_graph_name) + self.add_initializer(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1], raw=False) q_slice_output = matmul_node_name + "_q_out" q_slice = helper.make_node( @@ -719,6 +730,7 @@ def create_attention_node( present_k: str = "", present_v: str = "", scale: Optional[float] = None, + causal: bool = False, ) -> Union[NodeProto, None]: """Create an Attention node. @@ -739,6 +751,8 @@ def create_attention_node( past_v (str): name of input for past V value present_k (str): name of output to store present K value present_v (str): name of output to store present V value + scale: scale before softmax + causal: whether it is uni-directional mask. Returns: Union[NodeProto, None]: the node created or None if failed. @@ -823,7 +837,6 @@ def create_attention_node( assert q_bias_shape == k_bias_shape == qw_out_size assert v_bias_shape == vw_out_size - qkv_bias_dim = 0 if is_qkv_diff_dims: qkv_bias = np.concatenate((qb, kb, vb), axis=0) qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape @@ -834,33 +847,24 @@ def create_attention_node( attention_node_name = self.model.create_node_name("Attention") if not self.use_multi_head_attention: - weight = helper.make_tensor( + self.add_initializer( name=attention_node_name + "_qkv_weight", - data_type=TensorProto.FLOAT, + data_type=q_weight.data_type, dims=[qw_in_size, qkv_weight_dim], - vals=qkv_weight.flatten().tolist(), + vals=qkv_weight, ) - # Sometimes weights and bias are stored in fp16 - if q_weight.data_type == 10: - weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name)) - self.model.add_initializer(weight, self.this_graph_name) - - bias = None if has_bias: - bias = helper.make_tensor( + self.add_initializer( name=attention_node_name + "_qkv_bias", - data_type=TensorProto.FLOAT, + data_type=q_bias.data_type, dims=[qkv_bias_dim], - vals=qkv_bias.flatten().tolist(), + vals=qkv_bias, ) - if q_bias.data_type == 10: - bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name)) - self.model.add_initializer(bias, self.this_graph_name) # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights. if self.use_multi_head_attention: - if add_qk_str is not None: + if add_qk_str: logger.debug("MultiHeadAttention does not support relative_position_bias: cannot fuse the attention.") return None @@ -897,20 +901,7 @@ def create_attention_node( attention_inputs.append(past_kv) if add_qk_str is not None: - # Convert 4d mask from (B,1,M,M) to (B,N,M,M) - # B = batch size, M = max sequence length, N = num heads - concat_node_name = self.model.create_node_name("Concat") - mask_output_name = add_qk_str + "_mask" - concat_add_qk_fp32 = helper.make_node( - "Concat", - inputs=[add_qk_str for _ in range(num_heads)], - outputs=[mask_output_name], - name=concat_node_name, - axis=1, - ) - # Add new nodes to graph - self.nodes_to_add.append(concat_add_qk_fp32) - self.node_name_to_graph_name[concat_node_name] = self.this_graph_name + mask_output_name = self.reshape_add_qk(add_qk_str) # Add attention mask to attention node if not past_exists: @@ -933,6 +924,9 @@ def create_attention_node( attention_node.domain = "com.microsoft" attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + if causal: + attention_node.attribute.extend([helper.make_attribute("unidirectional", 1)]) + if scale is not None: attention_node.attribute.extend([helper.make_attribute("scale", scale)]) @@ -1166,6 +1160,13 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): attention_last_node = reshape_qkv if einsum_node is None else transpose_qkv q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q) + if q_num_heads <= 0 or q_hidden_size <= 0: + logger.warning( + "Failed to detect num_heads and hidden_size for Attention fusion. " + "Please specify those parameters in argument." + ) + return + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately new_node = self.create_attention_node( @@ -1191,14 +1192,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if einsum_node is not None: unique_index = einsum_node.input[0] new_edge = "edge_modified_" + unique_index - shape_tensor = helper.make_tensor( + + shape_tensor = self.add_initializer( name="shape_modified_tensor" + unique_index, data_type=TensorProto.INT64, dims=[4], - vals=np.int64([0, 0, q_num_heads, int(q_hidden_size / q_num_heads)]).tobytes(), - raw=True, + vals=np.int64([0, 0, q_num_heads, int(q_hidden_size / q_num_heads)]), + raw=False, ) - self.model.add_initializer(shape_tensor, self.this_graph_name) + self.model.add_node( helper.make_node( "Reshape", diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py new file mode 100644 index 0000000000..d400e248d6 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -0,0 +1,218 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import Tuple + +from fusion_attention import AttentionMask, FusionAttention +from fusion_options import AttentionMaskFormat +from onnx import NodeProto +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionAttentionClip(FusionAttention): + """ + Fuse Attention subgraph of Clip into one Attention node. + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + ): + attention_mask = AttentionMask(model) + attention_mask.mask_format = AttentionMaskFormat.NoMask + + super().__init__( + model, + hidden_size, + num_heads, + attention_mask, + use_multi_head_attention=False, + search_op_types=["SkipLayerNormalization"], + ) + + def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]: + """Detect num_heads and hidden_size for ONNX model from MiDaS + Args: + reshape_q (NodeProto): reshape node for q + Returns: + Tuple[int, int]: num_heads and hidden_size + """ + concat = self.model.match_parent(reshape_q, "Concat", 1) + if concat is None or len(concat.input) != 4: + return self.num_heads, self.hidden_size + + # The shape is a tensor like [?, ?, num_heads, head_size] + num_head_value = self.model.get_constant_value(concat.input[2]) + if num_head_value is None: + return self.num_heads, self.hidden_size # Fall back to user specified value + + if len(num_head_value) != 1 or num_head_value[0] <= 0: + return self.num_heads, self.hidden_size # Fall back to user specified value + + num_heads = num_head_value[0] + + head_size_value = self.model.get_constant_value(concat.input[3]) + if head_size_value is None: + return self.num_heads, self.hidden_size # Fall back to user specified value + + if len(head_size_value) != 1 or head_size_value[0] <= 0: + return self.num_heads, self.hidden_size # Fall back to user specified value + + head_size = head_size_value[0] + + hidden_size = num_heads * head_size + + if self.num_heads > 0 and num_heads != self.num_heads: + if self.num_heads_warning: + logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.") + self.num_heads_warning = False # Do not show the warning more than once + + if self.hidden_size > 0 and hidden_size != self.hidden_size: + if self.hidden_size_warning: + logger.warning( + f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value." + ) + self.hidden_size_warning = False # Do not show the warning more than once + + return num_heads, hidden_size + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + skip_input_index = None + node_before_layer_norm = None + for i in [1, 0]: + parent = self.model.match_parent(normalize_node, "SkipLayerNormalization", i) + if parent is not None: + skip_input_index = i + node_before_layer_norm = parent + + root_input = None + if node_before_layer_norm is not None: + root_input = node_before_layer_norm.output[0] + else: + # Deal with the first attention after the embedding layer. + for i in [0, 1]: + node_before_layer_norm = self.model.match_parent(normalize_node, "Add", i) + if node_before_layer_norm is None: + continue + child = self.model.find_first_child_by_type( + node_before_layer_norm, "LayerNormalization", input_name_to_nodes, False + ) + if child is None: + continue + root_input = child.output[0] + skip_input_index = i + break + + if skip_input_index is None: + return + + qkv_nodes = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [1 - skip_input_index, None, None, 0, 0, 0], + ) + if qkv_nodes is None: + return + + (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes + + v_nodes = self.model.match_parent_path( + matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None] + ) + if v_nodes is None: + logger.debug("fuse_attention: failed to match v path") + return + (_, _, reshape_v, add_v, matmul_v) = v_nodes + + add_mask_indices = [] + qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], + [0, 0, 0, None, 0], + return_indice=add_mask_indices, + ) + if qk_nodes is None: + logger.debug("fuse_attention: failed to match qk path") + return + assert len(add_mask_indices) == 1 + causal_mask_input_index = 1 - add_mask_indices[0] + + (_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes + + q_nodes = self.model.match_parent_path( + matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None] + ) + if q_nodes is None: + logger.debug("fuse_attention: failed to match q path") + return + (_, _transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes + + k_nodes = self.model.match_parent_path( + matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None] + ) + if k_nodes is None: + logger.debug("fuse_attention: failed to match k path") + return + + (_transpose_k, _reshape_k, _, _, add_k, matmul_k) = k_nodes + if matmul_q.input[0] != root_input or matmul_k.input[0] != root_input or matmul_v.input[0] != root_input: + logger.debug("fuse_attention: expect to have same input to q, k and v matmul") + return + + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q) + if num_heads <= 0 or hidden_size <= 0: + logger.debug("fuse_attention: failed to detect num_heads or hidden_size") + return + + attention_last_node = reshape_qkv + + # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path + # of computing causal mask. + causal_mask_nodes = self.model.match_parent_path( + add_mask, + ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0, 0], + ) + if causal_mask_nodes is None: + # If the model is exported with batch_size == 1, there is no Concat node + causal_mask_nodes = self.model.match_parent_path( + add_mask, + ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0], + ) + if causal_mask_nodes is None: + logger.debug("fuse_attention: failed to match causal mask subgraph") + return + + new_node = self.create_attention_node( + mask_index=None, + q_matmul=matmul_q, + k_matmul=matmul_k, + v_matmul=matmul_v, + q_add=add_q, + k_add=add_k, + v_add=add_v, + num_heads=num_heads, + hidden_size=hidden_size, + input=root_input, + output=attention_last_node.output[0], + add_qk_str=None, + scale=None, + causal=True, + ) + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py index f286206e5b..250ec5f3eb 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py @@ -210,15 +210,13 @@ def create_attention_node( ) matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV") - weight = helper.make_tensor( + self.add_initializer( name=matmul_node_name + "_weight", data_type=TensorProto.FLOAT, dims=[qkv_weight.shape[0], qkv_weight.shape[1]], - vals=qkv_weight.flatten().tolist(), + vals=qkv_weight, ) - self.model.add_initializer(weight, self.this_graph_name) - matmul_node = helper.make_node( "MatMul", inputs=[k_matmul.input[0], matmul_node_name + "_weight"], @@ -227,13 +225,13 @@ def create_attention_node( ) self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name - shape_tensor = helper.make_tensor( + self.add_initializer( name=matmul_node_name + "_reshape_shape", data_type=TensorProto.INT64, dims=[5], vals=[0, 0, n, 3, h], + raw=False, ) - self.model.add_initializer(shape_tensor, self.this_graph_name) reshape_node = helper.make_node( "Reshape", @@ -251,14 +249,12 @@ def create_attention_node( attention_node_name = self.model.create_node_name("Attention") - weight = helper.make_tensor( + self.add_initializer( name=attention_node_name + "_qkv_weight", data_type=TensorProto.FLOAT, dims=[qw_in_size, qkv_weight_dim], - vals=qkv_weight.flatten().tolist(), + vals=qkv_weight, ) - - self.model.add_initializer(weight, self.this_graph_name) else: # cross attention attention_node_name = self.model.create_node_name("MultiHeadAttention") if self.enable_packed_kv: @@ -282,15 +278,13 @@ def create_attention_node( kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h) matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV") - weight = helper.make_tensor( + self.add_initializer( name=matmul_node_name + "_weight", data_type=TensorProto.FLOAT, dims=[kv_weight.shape[0], kv_weight.shape[1]], - vals=kv_weight.flatten().tolist(), + vals=kv_weight, ) - self.model.add_initializer(weight, self.this_graph_name) - matmul_node = helper.make_node( "MatMul", inputs=[k_matmul.input[0], matmul_node_name + "_weight"], @@ -299,13 +293,13 @@ def create_attention_node( ) self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name - shape_tensor = helper.make_tensor( + self.add_initializer( name=matmul_node_name + "_reshape_shape", data_type=TensorProto.INT64, dims=[5], vals=[0, 0, n, 2, h], + raw=False, ) - self.model.add_initializer(shape_tensor, self.this_graph_name) reshape_node = helper.make_node( "Reshape", @@ -321,13 +315,12 @@ def create_attention_node( qkv_bias = np.zeros([3, hidden_size], dtype=np.float32) qkv_bias_dim = 3 * hidden_size - bias = helper.make_tensor( + self.add_initializer( name=attention_node_name + "_qkv_bias", data_type=TensorProto.FLOAT, dims=[qkv_bias_dim], - vals=qkv_bias.flatten().tolist(), + vals=qkv_bias, ) - self.model.add_initializer(bias, self.this_graph_name) if is_self_attention: if not self.enable_packed_qkv: @@ -375,6 +368,476 @@ def create_attention_node( self.increase_counter(counter_name) return attention_node + def create_attention_node_lora( + self, + q_matmul_add: NodeProto, + k_matmul_add: NodeProto, + v_matmul_add: NodeProto, + num_heads: int, + hidden_size: int, + input: str, + output: str, + ) -> Union[NodeProto, None]: + """Create an Attention node. + + Args: + q_matmul (NodeProto): MatMul node in fully connection for Q + k_matmul (NodeProto): MatMul node in fully connection for K + v_matmul (NodeProto): MatMul node in fully connection for V + num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. + hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. + input (str): input name + output (str): output name + + Returns: + Union[NodeProto, None]: the node created or None if failed. + """ + is_self_attention = not self.is_cross_attention + + q_matmul = self.model.match_parent(q_matmul_add, "MatMul", 0) + k_matmul = self.model.match_parent(k_matmul_add, "MatMul", 0) + v_matmul = self.model.match_parent(v_matmul_add, "MatMul", 0) + + q_lora_nodes = self.match_lora_path(q_matmul_add) + if q_lora_nodes is None: + return None + (q_lora_last_node, q_lora_matmul_1) = q_lora_nodes + + k_lora_nodes = self.match_lora_path(k_matmul_add) + if k_lora_nodes is None: + return None + (k_lora_last_node, k_lora_matmul_1) = k_lora_nodes + + v_lora_nodes = self.match_lora_path(v_matmul_add) + if v_lora_nodes is None: + return None + (v_lora_last_node, v_lora_matmul_1) = v_lora_nodes + + if is_self_attention: + if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input: + logger.debug( + "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s", + q_matmul.input[0], + k_matmul.input[0], + v_matmul.input[0], + ) + return None + + if ( + q_lora_matmul_1.input[0] != input + or k_lora_matmul_1.input[0] != input + or v_lora_matmul_1.input[0] != input + ): + logger.debug( + "For self attention, input hidden state for LoRA q and k/v weights shall be same. Got %s, %s, %s", + q_lora_matmul_1.input[0], + k_lora_matmul_1.input[0], + v_lora_matmul_1.input[0], + ) + return None + else: + if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input): + logger.debug( + "For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s", + q_matmul.input[0], + k_matmul.input[0], + v_matmul.input[0], + ) + return None + + if ( + q_lora_matmul_1.input[0] != input + or (k_lora_matmul_1.input[0] != v_lora_matmul_1.input[0]) + or (k_matmul.input[0] == input) + ): + logger.debug( + ( + "For cross attention, input hidden state for LoRA q and k/v weights shall be different. " + "Got %s, %s, %s" + ), + q_lora_matmul_1.input[0], + k_lora_matmul_1.input[0], + v_lora_matmul_1.input[0], + ) + return None + + if hidden_size > 0 and (hidden_size % num_heads) != 0: + logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") + return None + + q_weight = self.model.get_initializer(q_matmul.input[1]) + k_weight = self.model.get_initializer(k_matmul.input[1]) + v_weight = self.model.get_initializer(v_matmul.input[1]) + if not (q_weight and k_weight and v_weight): + return None + + # Sometimes weights are stored in fp16 + if q_weight.data_type == 10: + logger.debug("weights are in fp16. Please run fp16 conversion after optimization") + return None + + qw = NumpyHelper.to_array(q_weight) + kw = NumpyHelper.to_array(k_weight) + vw = NumpyHelper.to_array(v_weight) + logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}") + + # assert q and k have same shape as expected + if is_self_attention: + if qw.shape != kw.shape or qw.shape != vw.shape: + return None + + qw_in_size = qw.shape[0] + + if hidden_size > 0 and hidden_size != qw_in_size: + raise ValueError( + f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). " + "Please provide a correct input hidden size or pass in 0" + ) + + # All the matrices can have the same shape or q, k matrics can have the same shape with v being different + # For 2d weights, the shapes would be [in_size, out_size]. + # For 3d weights, shape would be [in_size, a, b] where a*b = out_size + qw_out_size = int(np.prod(qw.shape[1:])) + + if self.enable_packed_qkv: + attention_node_name = self.model.create_node_name("MultiHeadAttention") + + c = qw_in_size + n = num_heads + h = qw_out_size // num_heads + + # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 3, H] shape + qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape( + c, n * 3 * h + ) + + matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV") + self.add_initializer( + name=matmul_node_name + "_weight", + data_type=TensorProto.FLOAT, + dims=[qkv_weight.shape[0], qkv_weight.shape[1]], + vals=qkv_weight, + ) + + matmul_node = helper.make_node( + "MatMul", + inputs=[k_matmul.input[0], matmul_node_name + "_weight"], + outputs=[matmul_node_name + "_out"], + name=matmul_node_name, + ) + self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name + + # Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow + # the Q/K/V weights to be changed without having to re-run the optimizer. + lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape" + + self.add_initializer( + name=lora_weight_shape_tensor_name, + data_type=TensorProto.INT64, + dims=[4], + vals=[0, 0, n, h], + raw=False, + ) + + # Reshape the LoRA Q weights + q_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_Q") + q_lora_reshape_node = helper.make_node( + "Reshape", + inputs=[q_lora_last_node.output[0], lora_weight_shape_tensor_name], + outputs=[q_lora_reshape_node_name + "_out"], + name=q_lora_reshape_node_name, + ) + self.node_name_to_graph_name[q_lora_reshape_node.name] = self.this_graph_name + + # Reshape the LoRA K weights + k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K") + k_lora_reshape_node = helper.make_node( + "Reshape", + inputs=[k_lora_last_node.output[0], lora_weight_shape_tensor_name], + outputs=[k_lora_reshape_node_name + "_out"], + name=k_lora_reshape_node_name, + ) + self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name + + # Reshape the LoRA V weights + v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V") + v_lora_reshape_node = helper.make_node( + "Reshape", + inputs=[v_lora_last_node.output[0], lora_weight_shape_tensor_name], + outputs=[v_lora_reshape_node_name + "_out"], + name=v_lora_reshape_node_name, + ) + self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name + + # Concat the reshaped LoRA Q/K/V weights together on the third axis + qkv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_QKV") + qkv_lora_concat_node = helper.make_node( + "Concat", + inputs=[ + q_lora_reshape_node.output[0], + k_lora_reshape_node.output[0], + v_lora_reshape_node.output[0], + ], + outputs=[qkv_lora_concat_node_name + "_out"], + name=qkv_lora_concat_node_name, + ) + qkv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)]) + self.node_name_to_graph_name[qkv_lora_concat_node.name] = self.this_graph_name + + # Reshape the LoRA concatenated weights to [..., n * 3 * h] + reshaped_lora_weights_shape_tensor_name = qkv_lora_concat_node.name + "_reshape_shape" + self.add_initializer( + name=reshaped_lora_weights_shape_tensor_name, + data_type=TensorProto.INT64, + dims=[3], + vals=[0, 0, n * 3 * h], + raw=False, + ) + + qkv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_QKV") + qkv_lora_reshaped_node = helper.make_node( + "Reshape", + inputs=[qkv_lora_concat_node.output[0], reshaped_lora_weights_shape_tensor_name], + outputs=[qkv_lora_reshaped_node_name + "_out"], + name=qkv_lora_reshaped_node_name, + ) + self.node_name_to_graph_name[qkv_lora_reshaped_node.name] = self.this_graph_name + + # Add the LoRA Q/K/V weights to the base Q/K/V weights + add_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_QKV") + add_weights_node = helper.make_node( + "Add", + inputs=[qkv_lora_reshaped_node.output[0], matmul_node.output[0]], + outputs=[add_weights_node_name + "_out"], + name=add_weights_node_name, + ) + self.node_name_to_graph_name[add_weights_node.name] = self.this_graph_name + + # Finally, reshape the concatenated Q/K/V result to 5D + shape_tensor_name = add_weights_node_name + "_reshape_shape" + self.add_initializer( + name=shape_tensor_name, + data_type=TensorProto.INT64, + dims=[5], + vals=[0, 0, n, 3, h], + raw=False, + ) + + reshape_node = helper.make_node( + "Reshape", + inputs=[add_weights_node.output[0], shape_tensor_name], + outputs=[attention_node_name + "_qkv_input"], + name=add_weights_node_name + "_reshape", + ) + self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name + + self.nodes_to_add.extend( + [ + matmul_node, + q_lora_reshape_node, + k_lora_reshape_node, + v_lora_reshape_node, + qkv_lora_concat_node, + qkv_lora_reshaped_node, + add_weights_node, + reshape_node, + ] + ) + self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul, q_matmul_add, k_matmul_add, v_matmul_add]) + else: + # TODO: Support non-packed QKV + return None + else: # cross attention + attention_node_name = self.model.create_node_name("MultiHeadAttention") + if self.enable_packed_kv: + if kw.shape != vw.shape: + return None + + kw_in_size = kw.shape[0] + vw_in_size = vw.shape[0] + assert kw_in_size == vw_in_size + + qw_out_size = qw.shape[1] + kw_out_size = kw.shape[1] + vw_out_size = vw.shape[1] + assert qw_out_size == vw_out_size and kw_out_size == vw_out_size + + c = kw_in_size + n = num_heads + h = kw_out_size // num_heads + + # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape + kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h) + + matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV") + self.add_initializer( + name=matmul_node_name + "_weight", + data_type=TensorProto.FLOAT, + dims=[kv_weight.shape[0], kv_weight.shape[1]], + vals=kv_weight, + ) + + matmul_node = helper.make_node( + "MatMul", + inputs=[k_matmul.input[0], matmul_node_name + "_weight"], + outputs=[matmul_node_name + "_out"], + name=matmul_node_name, + ) + self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name + + # Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow + # the Q/K/V weights to be changed without having to re-run the optimizer. + kv_lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape" + self.add_initializer( + name=kv_lora_weight_shape_tensor_name, + data_type=TensorProto.INT64, + dims=[4], + vals=[0, 0, n, h], + raw=False, + ) + + # Reshape the LoRA K weights + k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K") + k_lora_reshape_node = helper.make_node( + "Reshape", + inputs=[k_lora_last_node.output[0], kv_lora_weight_shape_tensor_name], + outputs=[k_lora_reshape_node_name + "_out"], + name=k_lora_reshape_node_name, + ) + self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name + + # Reshape the LoRA V weights + v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V") + v_lora_reshape_node = helper.make_node( + "Reshape", + inputs=[v_lora_last_node.output[0], kv_lora_weight_shape_tensor_name], + outputs=[v_lora_reshape_node_name + "_out"], + name=v_lora_reshape_node_name, + ) + self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name + + # Concat the reshaped LoRA K/V weights together on the third axis + kv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_KV") + kv_lora_concat_node = helper.make_node( + "Concat", + inputs=[k_lora_reshape_node.output[0], v_lora_reshape_node.output[0]], + outputs=[kv_lora_concat_node_name + "_out"], + name=kv_lora_concat_node_name, + ) + kv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)]) + self.node_name_to_graph_name[kv_lora_concat_node.name] = self.this_graph_name + + # Reshape the LoRA concatenated weights to [..., n * 2 * h] + reshaped_kv_lora_weights_shape_tensor_name = kv_lora_concat_node.name + "_reshape_shape" + self.add_initializer( + name=reshaped_kv_lora_weights_shape_tensor_name, + data_type=TensorProto.INT64, + dims=[3], + vals=[0, 0, n * 2 * h], + raw=False, + ) + + kv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_KV") + kv_lora_reshaped_node = helper.make_node( + "Reshape", + inputs=[kv_lora_concat_node.output[0], reshaped_kv_lora_weights_shape_tensor_name], + outputs=[kv_lora_reshaped_node_name + "_out"], + name=kv_lora_reshaped_node_name, + ) + self.node_name_to_graph_name[kv_lora_reshaped_node.name] = self.this_graph_name + + # Add the LoRA K/V weights to the base K/V weights + add_kv_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_KV") + add_kv_weights_node = helper.make_node( + "Add", + inputs=[kv_lora_reshaped_node.output[0], matmul_node.output[0]], + outputs=[add_kv_weights_node_name + "_out"], + name=add_kv_weights_node_name, + ) + self.node_name_to_graph_name[add_kv_weights_node.name] = self.this_graph_name + + # Finally, reshape the concatenated K/V result to 5D + shape_tensor_name = add_kv_weights_node_name + "_reshape_shape" + self.add_initializer( + name=shape_tensor_name, + data_type=TensorProto.INT64, + dims=[5], + vals=[0, 0, n, 2, h], + raw=False, + ) + + reshape_node = helper.make_node( + "Reshape", + inputs=[add_kv_weights_node.output[0], shape_tensor_name], + outputs=[attention_node_name + "_kv_input"], + name=add_kv_weights_node_name + "_reshape", + ) + self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name + self.nodes_to_add.extend( + [ + matmul_node, + k_lora_reshape_node, + v_lora_reshape_node, + kv_lora_concat_node, + kv_lora_reshaped_node, + add_kv_weights_node, + reshape_node, + ] + ) + self.nodes_to_remove.extend([k_matmul, v_matmul, k_matmul_add, v_matmul_add]) + else: + # TODO: Support non-packed KV + return None + + # No bias, use zeros + qkv_bias = np.zeros([3, hidden_size], dtype=np.float32) + qkv_bias_dim = 3 * hidden_size + self.add_initializer( + name=attention_node_name + "_qkv_bias", + data_type=TensorProto.FLOAT, + dims=[qkv_bias_dim], + vals=qkv_bias, + ) + + if is_self_attention: + if not self.enable_packed_qkv: + # TODO: Support non-packed QKV + return None + else: + attention_inputs = [attention_node_name + "_qkv_input"] + else: + if not self.enable_packed_kv: + # TODO: Support non-packed QKV + return None + else: + attention_inputs = [ + q_matmul_add.output[0], + attention_node_name + "_kv_input", + ] + + attention_node = helper.make_node( + "Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention", + inputs=attention_inputs, + outputs=[output], + name=attention_node_name, + ) + attention_node.domain = "com.microsoft" + attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + + counter_name = ( + "Attention (self attention)" + if is_self_attention and not self.enable_packed_qkv + else "MultiHeadAttention ({})".format( + "self attention with packed qkv" + if self.enable_packed_qkv + else "cross attention with packed kv" + if self.enable_packed_kv + else "cross attention" + ) + ) + self.increase_counter(counter_name) + return attention_node + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0) @@ -397,30 +860,62 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return match_qkv = self.match_qkv_torch1(root_input, skip_add) or self.match_qkv_torch2(root_input, skip_add) - if match_qkv is None: - return - - is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match_qkv - - attention_last_node = reshape_qkv - - q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2) - if q_num_heads <= 0: - logger.debug("fuse_attention: failed to detect num_heads") - return + if match_qkv is not None: + is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match_qkv + + attention_last_node = reshape_qkv + + q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_attention_node( + matmul_q, + matmul_k, + matmul_v, + q_num_heads, + q_hidden_size, + input=normalize_node.output[0], + output=attention_last_node.output[0], + ) + if new_node is None: + return + else: + # Check if we have a LoRA pattern + match_qkv = self.match_qkv_torch1_lora(root_input, skip_add) or self.match_qkv_torch2_lora( + root_input, skip_add + ) + if match_qkv is None: + return + + is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v = match_qkv + + attention_last_node = reshape_qkv + + q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_attention_node_lora( + matmul_add_q, + matmul_add_k, + matmul_add_v, + q_num_heads, + q_hidden_size, + input=normalize_node.output[0], + output=attention_last_node.output[0], + ) + if new_node is None: + return - # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads - new_node = self.create_attention_node( - matmul_q, - matmul_k, - matmul_v, - q_num_heads, - q_hidden_size, - input=normalize_node.output[0], - output=attention_last_node.output[0], - ) - if new_node is None: - return + q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name @@ -530,3 +1025,146 @@ def match_qkv_torch2(self, root_input, skip_add): return None return True, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v + + def match_qkv_torch1_lora(self, root_input, skip_add): + """Match Q, K and V paths exported by PyTorch 1 that contains LoRA patterns.*""" + another_input = 1 if skip_add.input[0] == root_input else 0 + qkv_nodes = self.model.match_parent_path( + skip_add, + ["Add", "Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [another_input, 0, None, None, 0, 0, 0], + ) + if qkv_nodes is None: + return None + + (_, _, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes + + # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input. + v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0]) + if v_nodes is None: + logger.debug("fuse_attention: failed to match LoRA v path") + return None + (_, _, _, matmul_add_v) = v_nodes + + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0]) + if qk_nodes is not None: + (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes + else: + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]) + if qk_nodes is not None: + (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes + else: + logger.debug("fuse_attention: failed to match LoRA qk path") + return None + + q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "Add"], [0, 0, 0, 0]) + if q_nodes is None: + logger.debug("fuse_attention: failed to match LoRA q path") + return None + (_, _transpose_q, reshape_q, matmul_add_q) = q_nodes + + k_nodes = self.model.match_parent_path( + matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0, 0] + ) + if k_nodes is None: + logger.debug("fuse_attention: failed to match LoRA k path") + return None + + (_, _, _, _, matmul_add_k) = k_nodes + + return False, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v + + def match_qkv_torch2_lora(self, root_input, skip_add): + """Match Q, K and V paths exported by PyTorch 2 that contains LoRA patterns.*""" + another_input = 1 if skip_add.input[0] == root_input else 0 + qkv_nodes = self.model.match_parent_path( + skip_add, + ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [another_input, 0, None, None, 0, 0], + ) + if qkv_nodes is None: + return None + + (_, _, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes + + v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add"], [1, 0, 0]) + if v_nodes is None: + logger.debug("fuse_attention: failed to match LoRA v path") + return None + (_, _, matmul_add_v) = v_nodes + + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) + if qk_nodes is not None: + (_softmax_qk, matmul_qk) = qk_nodes + else: + logger.debug("fuse_attention: failed to match LoRA qk path") + return None + + q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [0, None, 0, 0]) + if q_nodes is None: + logger.debug("fuse_attention: failed to match LoRA q path") + return None + (mul_q, _transpose_q, reshape_q, matmul_add_q) = q_nodes + + k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [1, None, 0, 0]) + if k_nodes is None: + logger.debug("fuse_attention: failed to match LoRA k path") + return None + + (_mul_k, _, _, matmul_add_k) = k_nodes + + # The scalar for Q and K is sqrt(1.0/sqrt(head_size)). + mul_q_nodes = self.model.match_parent_path( + mul_q, + ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"], + [None, 0, 1, 0, 0, 0, 0, 0], + ) + if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q: + logger.debug("fuse_attention: failed to match LoRA mul_q path") + return None + + return True, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v + + def match_lora_path( + self, + add_node: NodeProto, + ): + # Lora paths can look like one of the following options: + # MatMul -> MatMul -> Add + # MatMul -> MatMul -> Mul -> Add + # MatMul -> MatMul -> Mul -> Mul -> Add + + # Try matching MatMul -> MatMul -> Add + lora_nodes = self.model.match_parent_path( + add_node, + ["MatMul", "MatMul"], + [1, 0], + ) + + if lora_nodes is not None: + (lora_matmul_2_node, lora_matmul_1_node) = lora_nodes + return (lora_matmul_2_node, lora_matmul_1_node) + + # Try matching MatMul -> MatMul -> Mul -> Add + lora_nodes = self.model.match_parent_path( + add_node, + ["Mul", "MatMul", "MatMul"], + [1, 0, 0], + ) + + if lora_nodes is not None: + (lora_mul_node, _, lora_matmul_1_node) = lora_nodes + return (lora_mul_node, lora_matmul_1_node) + + # Try matching MatMul -> MatMul -> Mul -> Mul -> Add + lora_nodes = self.model.match_parent_path( + add_node, + ["Mul", "Mul", "MatMul", "MatMul"], + [1, 0, 0, 0], + ) + + if lora_nodes is not None: + (lora_mul_node, _, _, lora_matmul_1_node) = lora_nodes + return (lora_mul_node, lora_matmul_1_node) + + return None diff --git a/onnxruntime/python/tools/transformers/fusion_attention_vae.py b/onnxruntime/python/tools/transformers/fusion_attention_vae.py index e91a8a61fc..151c04f933 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_vae.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_vae.py @@ -170,26 +170,23 @@ def create_attention_node( qkv_bias = np.stack((q_bias, k_bias, v_bias), axis=0) qkv_bias_dim = 3 * q_bias_shape - weight = helper.make_tensor( + self.add_initializer( name=attention_node_name + "_qkv_weight", data_type=TensorProto.FLOAT, dims=[qw_in_size, qkv_weight_dim], - vals=qkv_weight.flatten().tolist(), + vals=qkv_weight, ) - self.model.add_initializer(weight, self.this_graph_name) - # No bias, use zeros qkv_bias = np.zeros([3, hidden_size], dtype=np.float32) qkv_bias_dim = 3 * hidden_size - bias = helper.make_tensor( + self.add_initializer( name=attention_node_name + "_qkv_bias", data_type=TensorProto.FLOAT, dims=[qkv_bias_dim], - vals=qkv_bias.flatten().tolist(), + vals=qkv_bias, ) - self.model.add_initializer(bias, self.this_graph_name) attention_inputs = [ input_name, diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention.py b/onnxruntime/python/tools/transformers/fusion_bart_attention.py index 513c68a29d..71801401e9 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- import logging +import numpy as np from fusion_attention import AttentionMask, FusionAttention from onnx import TensorProto, helper from onnx_model import OnnxModel @@ -259,8 +260,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): empty_bias_name = "empty_bias" empty_tensor = self.model.get_initializer(empty_bias_name) if empty_tensor is None: - empty_tensor = helper.make_tensor(empty_bias_name, TensorProto.FLOAT, [bias_dim], [0.0] * bias_dim) - self.model.add_initializer(empty_tensor, self.this_graph_name) + self.add_initializer( + empty_bias_name, + TensorProto.FLOAT, + dims=[bias_dim], + vals=np.array([0.0] * bias_dim, dtype=np.float32), + ) add_name = self.model.create_node_name("Add") add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name) diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py index d53a2f4ba4..67f4f0b55c 100644 --- a/onnxruntime/python/tools/transformers/fusion_base.py +++ b/onnxruntime/python/tools/transformers/fusion_base.py @@ -4,9 +4,10 @@ # -------------------------------------------------------------------------- from collections import defaultdict from logging import getLogger -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Sequence, Union -from onnx import NodeProto +import numpy as np +from onnx import NodeProto, helper from onnx_model import OnnxModel logger = getLogger(__name__) @@ -86,3 +87,51 @@ def apply(self): self.model.prune_graph() elif self.nodes_to_remove or self.nodes_to_add: self.model.update_graph() + + def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True): + if raw: + np_type = helper.tensor_dtype_to_np_dtype(data_type) + if not isinstance(vals, np.ndarray): + bytes = np.array(vals, dtype=np_type).tobytes() + else: + bytes = vals.astype(np_type).tobytes() + tensor = helper.make_tensor( + name=name, + data_type=data_type, + dims=dims, + vals=bytes, + raw=True, + ) + else: + tensor = helper.make_tensor( + name=name, + data_type=data_type, + dims=dims, + vals=vals, + raw=False, + ) + + self.model.add_initializer(tensor, self.this_graph_name) + return tensor + + def add_nodes_to_remove(self, nodes: List[NodeProto]): + # Some nodes are shared between paths (e.g. rotary embedding nodes in the Q and K paths). + # When path A is fused, its shared nodes are added to `self.nodes_to_remove`. But when path B + # is fused, its shared nodes are also added to `self.nodes_to_remove`. When the nodes are + # iteratively removed from `self.nodes_to_remove`, path A's shared nodes are removed first. + # Since path A's shared nodes are removed, path B's shared nodes are not removed because they + # were previously removed for path A. This causes an error to print in remove_node that a node + # has failed to be removed. + # + # To avoid this error, we pre-emptively check if the shared nodes are already in `self.nodes_to_remove`. + # We could alternatively convert `self.nodes_to_remove` to a set to avoid this issue, but there could + # be scenarios where the nodes need to be removed in a specific order and converting to a set would + # lose this order. + for node in nodes: + if node not in self.nodes_to_remove: + self.nodes_to_remove.append(node) + + def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]): + for node in nodes: + if node not in self.nodes_to_remove and node not in nodes_to_keep: + self.nodes_to_remove.append(node) diff --git a/onnxruntime/python/tools/transformers/fusion_embedlayer.py b/onnxruntime/python/tools/transformers/fusion_embedlayer.py index a20febb9f0..bc38399e3c 100644 --- a/onnxruntime/python/tools/transformers/fusion_embedlayer.py +++ b/onnxruntime/python/tools/transformers/fusion_embedlayer.py @@ -378,7 +378,7 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected") return False - # In normal case, word embeding table is the largest, and segment embedding table is the smallest, while postion embedding table is in between. + # In normal case, word embedding table is the largest, and segment embedding table is the smallest, while position embedding table is in between. # TODO: use other information (like initializer names) to identify different embedding weights automatically. if word_embedding_table.shape[0] <= position_embedding_table.shape[0]: logger.warning( @@ -430,6 +430,7 @@ def create_fused_node( segment_embedding_gather: Union[None, NodeProto], position_ids: Optional[str] = None, embedding_sum_output=False, + embedding_sum_name=None, ): """Create an EmbedLayerNormalization node. Note that segment embedding is optional. @@ -487,7 +488,8 @@ def create_fused_node( embed_node_outputs = [node_name + "_output", node_name + "_dummy_mask_index"] if embedding_sum_output: - embed_node_outputs.append(node_name + "_embedding_sum") + name = embedding_sum_name if embedding_sum_name is not None else node_name + "_embedding_sum" + embed_node_outputs.append(name) embed_node = helper.make_node( "EmbedLayerNormalization", @@ -522,19 +524,8 @@ def finish_fusion(self, layernorm, embed_node): # use prune graph to remove nodes that is not needed self.prune_graph = True - def is_embedding_sum_needed(self, add_before_layer_norm): - """Check that Add before layer norm has an output to add before next layernorm - - Args: - add_before_layer_norm (NodeProto): Add before any LayerNormalization node in topological order of graph - - Returns: - bool: whether there is an extra output needed out of embed layer norm node - """ - - nodes = self.model.get_children(add_before_layer_norm) - - return len(nodes) > 1 + def is_skip_layer_norm_with_sum_output(self, node): + return (node.op_type == "SkipLayerNormalization") and len(node.output) > 3 and len(node.output[3]) > 0 def fuse_gpt2( self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather=None @@ -570,21 +561,31 @@ def fuse_gpt2( if not self.check_embedding(word_embedding_gather, None, position_embedding_gather): return False - # If the add_before_layernorm node is an Add node, then the add_output output is the first index - # output of this node. - - # If the add_before_layernorm node is SkipLayerNormalization node, then the add_output output + # If layernorm node is SkipLayerNormalization, we need look at its optional fourth output. + # If the add_before_layernorm node is an Add node, then the add_output output is the first output of this node. + # If the add_before_layernorm node is a SkipLayerNormalization node, then the add_output output # is the (optional) fourth index output of this node. - add_output = None - optional_embedding_sum_output = False - if (add_before_layernorm.op_type == "Add" and self.is_embedding_sum_needed(add_before_layernorm)) or ( - add_before_layernorm.op_type == "SkipLayerNormalization" and len(add_before_layernorm.output) >= 4 - ): - optional_embedding_sum_output = True - add_output = ( - add_before_layernorm.output[0] - if add_before_layernorm.op_type == "Add" - else add_before_layernorm.output[3] + # When add_before_layernorm is SkipLayerNormalization, add_before_layernorm and layernorm are same node. + if layernorm.op_type == "SkipLayerNormalization": + need_embedding_sum_output = self.is_skip_layer_norm_with_sum_output(layernorm) + sum_output_index = 3 + node_with_sum_output = layernorm + sum_output = layernorm.output[3] if need_embedding_sum_output else None + is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None) + else: # layernorm.op_type == "LayerNormalization" + node_with_sum_output = add_before_layernorm + sum_output_index = 0 if add_before_layernorm.op_type == "Add" else 3 + sum_output = ( + add_before_layernorm.output[sum_output_index] + if len(add_before_layernorm.output) > sum_output_index + else None + ) + is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None) + is_sum_used_by_multiple_nodes = ( + sum_output and (sum_output in input_name_to_nodes) and len(input_name_to_nodes[sum_output]) > 1 + ) + need_embedding_sum_output = (sum_output is not None) and ( + add_before_layernorm.op_type != "Add" or is_sum_graph_output or is_sum_used_by_multiple_nodes ) # make the fused node @@ -595,14 +596,16 @@ def fuse_gpt2( position_embedding_gather, optional_segment_gather, position_ids, - optional_embedding_sum_output, + embedding_sum_output=need_embedding_sum_output, + embedding_sum_name=sum_output if is_sum_graph_output else None, ) - # direct the output to another add too - self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0]) - if optional_embedding_sum_output: - self.model.replace_input_of_all_nodes(add_output, embed_node.output[2]) + if need_embedding_sum_output: + node_with_sum_output.output[sum_output_index] = "_no_use__to_be_removed_" + if not is_sum_graph_output: + self.model.replace_input_of_all_nodes(sum_output, embed_node.output[2]) + self.finish_fusion(layernorm, embed_node) return True def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node): @@ -707,9 +710,14 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): gather_0_path = self.model.match_parent_path(node, ["Gather"], [0]) gather_1_path = self.model.match_parent_path(node, ["Gather"], [1]) if gather_0_path is None and gather_1_path is not None: + if first_add_path is None: + return add_before_layernorm = first_add_path[0] optional_segment_gather = gather_1_path[0] elif gather_0_path is not None and gather_1_path is None: + first_add_path = self.model.match_parent_path(node, ["Add"], [1]) + if first_add_path is None: + return add_before_layernorm = first_add_path[0] optional_segment_gather = gather_0_path[0] else: diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py index 7b9e758178..a3f98d411e 100644 --- a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py @@ -239,7 +239,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): [0, None, 0, 0, 0, 0, 0], output_name_to_node=output_name_to_node, return_indice=return_indice, - ) # yapf: disable + ) else: qkv_nodes = self.model.match_parent_path( normalize_node, @@ -247,7 +247,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): [None, 0, 0, 0, 0, 0], output_name_to_node=output_name_to_node, return_indice=return_indice, - ) # yapf: disable + ) if qkv_nodes is None: return @@ -361,7 +361,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): "Div", ], [1, 0, 1, 0, 1, 0, 0, 0, 0, 0], - ) # yapf: disable + ) if mask_nodes is None: logger.debug("fuse_attention: failed to match unidirectional mask path") return @@ -414,7 +414,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ), # useless cast and reshape are removed. ], output_name_to_node, - ) # yapf: disable + ) if input_mask_nodes is None: logger.debug("fuse_attention: failed to match input attention mask path") return @@ -437,7 +437,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ), ], output_name_to_node, - ) # yapf: disable + ) if mask_nodes is None: # TODO: match mask path for GPT2LMHeadModel_BeamSearchStep. logger.debug("fuse_attention: failed to match mask path") diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention_megatron.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention_megatron.py index 052dd243fd..7eb774b746 100644 --- a/onnxruntime/python/tools/transformers/fusion_gpt_attention_megatron.py +++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention_megatron.py @@ -72,9 +72,7 @@ def fuse_attention_node( self.prune_graph = True def match_mask(self, sub_qk, mul_qk, matmul_qk, layernorm_before_attention): - mask_nodes = self.model.match_parent_path( - sub_qk, ["Mul", "Sub", "Slice", "Slice"], [1, 0, 1, 0] - ) # yapf: disable + mask_nodes = self.model.match_parent_path(sub_qk, ["Mul", "Sub", "Slice", "Slice"], [1, 0, 1, 0]) if mask_nodes is None: logger.debug("fuse_attention: failed to match unidirectional mask path") return None @@ -176,14 +174,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"], [0, 1, None, 0, 0, 0], output_name_to_node=output_name_to_node, - ) # yapf: disable + ) else: qkv_nodes = self.model.match_parent_path( normalize_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0], output_name_to_node=output_name_to_node, - ) # yapf: disable + ) if qkv_nodes is None: return @@ -223,7 +221,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): "LayerNormalization", ], [1, 1, 0, 0, 0, None, 0], - ) # yapf: disable + ) if v_nodes is None: v_nodes = self.model.match_parent_path( @@ -238,7 +236,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): "SkipLayerNormalization", ], [1, 1, 0, 0, 0, None, 0], - ) # yapf: disable + ) if v_nodes is None: logger.debug("fuse_attention: failed to match v path") diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py index 83fa51dcfa..b217743c4a 100644 --- a/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py +++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py @@ -76,7 +76,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): [0, None, 0, 0, 0, 0, 0], output_name_to_node=output_name_to_node, return_indice=return_indice, - ) # yapf: disable + ) else: qkv_nodes = self.model.match_parent_path( normalize_node, @@ -84,7 +84,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): [None, 0, 0, 0, 0, 0], output_name_to_node=output_name_to_node, return_indice=return_indice, - ) # yapf: disable + ) if qkv_nodes is None: return @@ -116,7 +116,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): matmul_qkv, ["Transpose", "Reshape", "Split", "Reshape", "Gemm", "Reshape"], [1, 0, 0, 0, 0, 0], - ) # yapf: disable + ) if v_nodes is None: logger.debug("fuse_attention: failed to match v path") return @@ -168,7 +168,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): "Div", ], [1, 0, 1, 0, 1, 0, 0, 0, 0, 0], - ) # yapf: disable + ) if mask_nodes is None: logger.debug("fuse_attention: failed to match mask path") return @@ -201,7 +201,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): "Div", ], [0, 0, 0, 1, 0, 0, 0, 0, 0], - ) # yapf: disable + ) if mask_nodes is None: logger.debug("fuse_attention: failed to match mask path") return @@ -225,7 +225,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): mul_qk, ["Slice", "Slice", "Unsqueeze", "Squeeze", "Slice", "Shape", "Div"], [1, 0, 2, 0, 0, 0, 0], - ) # yapf: disable + ) if mask_nodes is None: logger.debug("fuse_attention: failed to match mask path") return diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py index 2cae366d3f..c718d2c27e 100644 --- a/onnxruntime/python/tools/transformers/fusion_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py @@ -82,19 +82,11 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): return instance_norm_scale = self.model.get_constant_value(instance_norm.input[1]) - if instance_norm_scale is None: - return - instance_norm_bias = self.model.get_constant_value(instance_norm.input[2]) - if instance_norm_bias is None: + if instance_norm_scale is None or len(instance_norm_scale.shape) != 1: return - if not ( - len(instance_norm_scale.shape) == 1 - and len(instance_norm_bias.shape) == 1 - and instance_norm_scale.shape == instance_norm_bias.shape - and instance_norm_scale.shape[0] == 32 - ): - logger.info("InstanceNormalization groups=%d", instance_norm_scale.shape[0]) + instance_norm_bias = self.model.get_constant_value(instance_norm.input[2]) + if instance_norm_bias is None or instance_norm_scale.shape != instance_norm_scale.shape: return if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale): @@ -104,24 +96,19 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict): group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm") - if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]: - logger.info("GroupNorm channels=%d", weight_elements) - - gamma = helper.make_tensor( + self.add_initializer( name=group_norm_name + "_gamma", data_type=TensorProto.FLOAT, dims=[weight_elements], - vals=weight.flatten().tolist(), + vals=weight, ) - self.model.add_initializer(gamma, self.this_graph_name) - beta = helper.make_tensor( + self.add_initializer( name=group_norm_name + "_beta", data_type=TensorProto.FLOAT, dims=[bias_elements], - vals=bias.flatten().tolist(), + vals=bias, ) - self.model.add_initializer(beta, self.this_graph_name) last_node = add_node subgraph_nodes = [add_node, weight_mul, reshape_4d, instance_norm, reshape_3d, shape_node] diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index ec485e0dfa..68d26fc46f 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -187,7 +187,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): ), ], output_name_to_node, - ) # yapf: disable + ) if parent_nodes is None: return diff --git a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py index d8ecb65280..141ebb1f95 100644 --- a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py +++ b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py @@ -54,13 +54,12 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node): weight = weight.transpose(0, 2, 3, 1) weight_name = node_name + "_weight_NHWC" - nhwc_weight = helper.make_tensor( + self.add_initializer( name=weight_name, data_type=TensorProto.FLOAT, dims=list(weight.shape), - vals=weight.flatten().tolist(), + vals=weight, ) - self.model.add_initializer(nhwc_weight, self.this_graph_name) weight_transpose_node = None else: weight_transpose_node = self.create_transpose_node(conv.input[1], [0, 2, 3, 1]) diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 57f0fea99d..b9b92d2fe8 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -26,6 +26,7 @@ def __init__(self, model_type): self.enable_gelu = True self.enable_layer_norm = True self.enable_attention = True + self.enable_rotary_embeddings = True # Use MultiHeadAttention instead of Attention operator. The difference: # (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since 3 MatMul is @@ -45,6 +46,9 @@ def __init__(self, model_type): self.enable_gemm_fast_gelu = False self.group_norm_channels_last = True + if model_type == "clip": + self.enable_embed_layer_norm = False + # Set default to sequence length for BERT model to use fused attention to speed up. # Note that embed layer normalization will convert 2D mask to 1D when mask type is MaskIndexEnd. self.attention_mask_format = AttentionMaskFormat.AttentionMask @@ -57,6 +61,7 @@ def __init__(self, model_type): if model_type in ["unet", "vae", "clip"]: self.enable_nhwc_conv = True self.enable_group_norm = True + self.enable_skip_group_norm = True self.enable_bias_splitgelu = True self.enable_packed_qkv = True self.enable_packed_kv = True @@ -78,6 +83,8 @@ def parse(args): options.enable_gelu = False if args.disable_layer_norm: options.enable_layer_norm = False + if args.disable_rotary_embeddings: + options.enable_rotary_embeddings = False if args.disable_attention: options.enable_attention = False if args.use_multi_head_attention: @@ -110,6 +117,8 @@ def parse(args): options.enable_nhwc_conv = False if args.disable_group_norm: options.enable_group_norm = False + if args.disable_skip_group_norm: + options.enable_skip_group_norm = False if args.disable_bias_splitgelu: options.enable_bias_splitgelu = False if args.disable_packed_qkv: @@ -244,6 +253,14 @@ def add_arguments(parser: ArgumentParser): ) parser.set_defaults(disable_group_norm=False) + parser.add_argument( + "--disable_skip_group_norm", + required=False, + action="store_true", + help="not fuse Add + GroupNorm to SkipGroupNorm. Only works for model_type=unet or vae", + ) + parser.set_defaults(disable_skip_group_norm=False) + parser.add_argument( "--disable_packed_kv", required=False, @@ -291,3 +308,10 @@ def add_arguments(parser: ArgumentParser): help="Use channels_first (NCHW) instead of channels_last (NHWC) for GroupNorm. Only works for model_type=unet or vae", ) parser.set_defaults(use_group_norm_channels_first=False) + + parser.add_argument( + "--disable_rotary_embeddings", + required=False, + action="store_true", + help="Do not fuse rotary embeddings into RotaryEmbedding op", + ) diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py new file mode 100644 index 0000000000..de89b35366 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -0,0 +1,1382 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging +from typing import Optional, Union + +from fusion_attention import FusionAttention +from fusion_base import Fusion +from onnx import FunctionProto, NodeProto, TensorProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionRotaryAttention(FusionAttention): + """ + Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node. + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + ): + super().__init__( + model, + hidden_size, + num_heads, + use_multi_head_attention=True, + search_op_types=[ + "SimplifiedLayerNormalization", + "SkipSimplifiedLayerNormalization", + "LayerNormalization", + "SkipLayerNormalization", + "Add", + ], + ) + + def create_mha_node( + self, + input: str, + output: str, + q_rotary: NodeProto, + k_rotary: NodeProto, + v_matmul: NodeProto, + attn_mask: str = "", + add_qk: str = "", + past_k: str = "", + past_v: str = "", + present_k: str = "", + present_v: str = "", + scale: Optional[float] = None, + ) -> Union[NodeProto, None]: + assert self.num_heads > 0 + + if self.hidden_size > 0 and (self.hidden_size % self.num_heads) != 0: + logger.debug( + f"fuse_rotary_attention: input hidden size {self.hidden_size} is not a multiple of num of heads {self.num_heads}" + ) + return None + + mha_node_name = self.model.create_node_name("MultiHeadAttention") + mha_inputs = [ + q_rotary.output[0], + k_rotary.output[0], + v_matmul.output[0], + "", # bias + attn_mask, # key_padding_mask + add_qk, # relative_position_bias + past_k, + past_v, + ] + + mha_outputs = [output] + if present_k and present_v: + mha_outputs.extend([present_k, present_v]) + + mha_node = helper.make_node( + "MultiHeadAttention", + inputs=mha_inputs, + outputs=mha_outputs, + name=mha_node_name, + ) + + mha_node.domain = "com.microsoft" + mha_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)]) + if scale is not None: + mha_node.attribute.extend([helper.make_attribute("scale", scale)]) + if self.mask_filter_value is not None: + mha_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))]) + + self.increase_counter("MultiHeadAttention") + return mha_node + + def check_runtime_shape_paths_for_function( + self, + reshape_qkv_2, # Reshape after Transpose + reshape_qkv_1, # Reshape before Transpose + reshape_q_2, # Reshape after RotaryEmbedding + reshape_k_2, # Reshape after RotaryEmbedding + reshape_v_2, # Reshape after Transpose + reshape_v_1, # Reshape before Transpose + add_qk, # Add before Softmax + root_input, # Root input to attention subgraph + ): + # Check #1: check paths for qkv nodes + concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1]) + concat_qkv_1_path = self.model.match_parent_path(reshape_qkv_1, ["Concat"], [1]) + if concat_qkv_2_path is None or concat_qkv_1_path is None: + return False + concat_qkv_2, concat_qkv_1 = concat_qkv_2_path[0], concat_qkv_1_path[0] + + reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + reshape_qkv_1_path_1 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_1_path_2 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0]) + if ( + reshape_qkv_2_path_1 is None + or reshape_qkv_2_path_2 is None + or reshape_qkv_1_path_1 is None + or reshape_qkv_1_path_2 is None + ): + return False + + _, gather_1, shape_1 = reshape_qkv_2_path_1 + _, gather_2, shape_2 = reshape_qkv_2_path_2 + + # Check root_input --> Shape --> Gather connection + if shape_1.input[0] != root_input or shape_2.input[0] != root_input: + return False + + # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_qkv_1_path_1 and reshape_qkv_1_path_2 + if reshape_qkv_1_path_1[1].name != gather_1.name or reshape_qkv_1_path_2[1].name != gather_2.name: + return False + + # Check #2: check paths for v nodes + concat_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat"], [1]) + concat_v_1_path = self.model.match_parent_path(reshape_v_1, ["Concat"], [1]) + if concat_v_2_path is None or concat_v_1_path is None: + return False + concat_v_2, concat_v_1 = concat_v_2_path[0], concat_v_1_path[0] + + reshape_v_2_path_1 = self.model.match_parent_path( + concat_v_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0] + ) + reshape_v_2_path_2 = self.model.match_parent_path( + concat_v_2, ["Unsqueeze", "Add", "Gather", "Shape"], [1, 0, 0, 0] + ) + reshape_v_1_path_1 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_v_1_path_2 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if ( + reshape_v_2_path_1 is None + or reshape_v_2_path_2 is None + or reshape_v_1_path_1 is None + or reshape_v_1_path_2 is None + ): + return False + + # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_1 + # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_2 + # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_v_1_path_1 and reshape_v_1_path_2 + if ( + reshape_v_2_path_1[2].name != gather_1.name + or reshape_v_2_path_2[2].name != gather_2.name + or reshape_v_1_path_1[1].name != gather_1.name + or reshape_v_1_path_2[1].name != gather_2.name + ): + return False + + # Check #3: check paths for k nodes + concat_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat"], [1]) + if concat_k_2_path is None: + return False + concat_k_2 = concat_k_2_path[0] + + reshape_k_2_path_1 = self.model.match_parent_path( + concat_k_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0] + ) + reshape_k_2_path_2 = self.model.match_parent_path( + concat_k_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 0, 0] + ) + if reshape_k_2_path_1 is None or reshape_k_2_path_2 is None: + return False + + # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_1 + # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_2 + if reshape_k_2_path_1[2].name != gather_1.name or reshape_k_2_path_2[2].name != gather_2.name: + return False + + # Check #4: check paths for q nodes + concat_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat"], [1]) + if concat_q_2_path is None: + return False + concat_q_2 = concat_q_2_path[0] + + reshape_q_2_path_1 = self.model.match_parent_path( + concat_q_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0] + ) + reshape_q_2_path_2 = self.model.match_parent_path(concat_q_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_q_2_path_1 is None or reshape_q_2_path_2 is None: + return False + + # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_1 + # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_2 + if reshape_q_2_path_1[2].name != gather_1.name or reshape_q_2_path_2[1].name != gather_2.name: + return False + + # Check #5: check Mul nodes are the same for q, k, v + mul_q = reshape_q_2_path_1[1] + mul_k = reshape_k_2_path_1[1] + mul_v = reshape_v_2_path_1[1] + gather_1_out = gather_1.output[0] + if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out: + return False + + # Check #6: check paths for attention mask nodes + attn_mask_path_1 = self.model.match_parent_path(add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0]) + attn_mask_path_2 = self.model.match_parent_path(add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0]) + if attn_mask_path_1 is not None: + _, slice_qk_2, slice_qk_1 = attn_mask_path_1 + elif attn_mask_path_2 is not None: + _, _, slice_qk_2, slice_qk_1 = attn_mask_path_2 + else: + return False + # Check first input to Slice #1 is 3D attention mask of shape (B,S,T) + if slice_qk_1.input[0] not in {"attn_mask", "attention_mask"}: + return False + + slice_qk_2_path = self.model.match_parent_path( + slice_qk_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0] + ) + slice_qk_1_path_1 = self.model.match_parent_path( + slice_qk_1, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0] + ) + slice_qk_1_path_2 = self.model.match_parent_path(slice_qk_1, ["Unsqueeze"], [1]) + if slice_qk_2_path is None or slice_qk_1_path_1 is None or slice_qk_1_path_2 is None: + return False + + # Check Gather --> Add --> Unsqueeze #3 --> Slice #2 connection for slice_qk_2_path + # Check Gather --> Add --> Unsqueeze #2 --> Slice #1 connection for slice_qk_1_path_1 + if slice_qk_2_path[1].name != slice_qk_1_path_1[1].name or slice_qk_2_path[2].name != slice_qk_1_path_1[2].name: + return False + + # Check Unsqueeze #1 --> Slice #1 connection for slice_qk_1_path_2 + # Check if first input to Add and Unsqueeze #1 is position ids + if slice_qk_1_path_1[1].input[0] != slice_qk_1_path_2[0].input[0]: + return False + + return True + + def check_runtime_shape_paths_for_nodes( + self, + reshape_qkv, # Final reshape before o_proj MatMul + reshape_q, # Reshape before q_proj MatMul + reshape_k, # Reshape before k_proj MatMul + reshape_v, # Reshape before v_proj MatMul + root_input, # Root input to attention subgraph + ): + # Check #1: check paths for qkv nodes + concat_qkv_path = self.model.match_parent_path(reshape_qkv, ["Concat"], [1]) + if concat_qkv_path is None: + return False + concat_qkv = concat_qkv_path[0] + + reshape_qkv_path_1 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_path_2 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_qkv_path_1 is None or reshape_qkv_path_2 is None: + return False + + _, gather_1, shape_1 = reshape_qkv_path_1 + _, gather_2, shape_2 = reshape_qkv_path_2 + + # Check root_input --> Shape --> Gather connection + if shape_1.input[0] != root_input or shape_2.input[0] != root_input: + return False + + # Check #2: check paths for v nodes + concat_v_path = self.model.match_parent_path(reshape_v, ["Concat"], [1]) + if concat_v_path is None: + return False + concat_v = concat_v_path[0] + + reshape_v_path_1 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_v_path_2 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_v_path_1 is None or reshape_v_path_2 is None: + return False + + # Check Gather --> Unsqueeze --> Concat --> Reshape connection + if reshape_v_path_1[1].name != gather_1.name or reshape_v_path_2[1].name != gather_2.name: + return False + + # Check #3: check paths for k nodes + concat_k_path = self.model.match_parent_path(reshape_k, ["Concat"], [1]) + if concat_k_path is None: + return False + concat_k = concat_k_path[0] + + reshape_k_path_1 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_k_path_2 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_k_path_1 is None or reshape_k_path_2 is None: + return False + + # Check Gather --> Unsqueeze --> Concat --> Reshape connection + if reshape_k_path_1[1].name != gather_1.name or reshape_k_path_2[1].name != gather_2.name: + return False + + # Check #4: check paths for q nodes + concat_q_path = self.model.match_parent_path(reshape_q, ["Concat"], [1]) + if concat_q_path is None: + return False + concat_q = concat_q_path[0] + + reshape_q_path_1 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_q_path_2 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_q_path_1 is None or reshape_q_path_2 is None: + return False + + # Check Gather --> Unsqueeze --> Concat --> Reshape connection + if reshape_q_path_1[1].name != gather_1.name or reshape_q_path_2[1].name != gather_2.name: + return False + + return True + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + if normalize_node.op_type not in {"SkipSimplifiedLayerNormalization", "SkipLayerNormalization", "Add"}: + return + + # qkv_nodes_1 is for LLaMA-2 Microsoft + # qkv_nodes_2 is for LLaMA-2 Hugging Face + # qkv_nodes_3 is for LLaMA-2 distribute Hugging Face model + qkv_nodes = None + qkv_nodes_1 = self.model.match_parent_path( + normalize_node, + ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 0], + ) + qkv_nodes_2 = self.model.match_parent_path( + normalize_node, + ["MatMul", "Reshape", "Transpose", "MatMul"], + [1, 0, 0, 0], + ) + qkv_nodes_3 = self.model.match_parent_path( + normalize_node, + ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 0, 0, 0, 0], + ) + if qkv_nodes_1 is not None: + _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1 + qkv_nodes = qkv_nodes_1 + elif qkv_nodes_2 is not None: + _, reshape_qkv, _, matmul_qkv = qkv_nodes_2 + qkv_nodes = qkv_nodes_2 + elif qkv_nodes_3 is not None: + _, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3 + qkv_nodes = qkv_nodes_3 + else: + logger.debug("fuse_rotary_attention: failed to match qkv nodes") + return + + # v_nodes_1 is for LLaMA-2 Microsoft + # v_nodes_3 is for LLaMA-2 Hugging Face + # v_nodes_4 is for LLaMA-2 70B model + past_v, present_v, past_seq_len = "", "", "" + v_nodes = None + v_nodes_1 = self.model.match_parent_path( + matmul_qkv, + ["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 1, 0, 0], + ) + v_nodes_2 = self.model.match_parent_path( + matmul_qkv, + ["Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 0, 0], + ) + v_nodes_3 = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "MatMul"], + [1, 0, 0], + ) + _, v_nodes_4, _ = self.model.match_parent_paths_all( + matmul_qkv, + [ + ( + ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "Equal", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "Equal", + "Mul", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 0, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Concat", + "Unsqueeze", + "Mul", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 1, 1, 0, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 2, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 3, 0, 0, 0, 1, 0, 0], + ), + ], + output_name_to_node=None, + ) + if v_nodes_1 is not None: + reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1 + v_nodes = v_nodes_1 + + concat_v_path = self.model.match_parent_path( + concat_v, + ["Slice", "Unsqueeze"], + [0, 2], + ) + if concat_v_path is None: + logger.debug("fuse_rotary_attention: failed to match past/present concat in v path") + return + + past_v = concat_v_path[0].input[0] + past_seq_len = concat_v_path[-1].input[0] + present_v = concat_v.output[0] + elif v_nodes_2 is not None: + concat_v, transpose_v, reshape_v, matmul_v = v_nodes_2 + v_nodes = v_nodes_2 + past_v = concat_v.input[0] + present_v = concat_v.output[0] + elif v_nodes_3 is not None: + transpose_v, reshape_v, matmul_v = v_nodes_3 + v_nodes = v_nodes_3 + present_v = transpose_v.output[0] + elif v_nodes_4 is not None and len(v_nodes_4) == 9: + concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:] + v_nodes = v_nodes_4 + past_v = concat_v.input[0] + present_v = concat_v.output[0] + else: + logger.debug("fuse_rotary_attention: failed to match v path") + return + + qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Softmax", "Add", "Div", "MatMul"], + [0, 0, 0, 0], + ) + add_qk, matmul_qk = None, None + if qk_nodes is not None: + _, add_qk, _, matmul_qk = qk_nodes + else: + logger.debug("fuse_rotary_attention: failed to match qk nodes") + return + + # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask + # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask + attn_mask, add_qk_str = "", "" + attn_mask_nodes_1 = self.model.match_parent_path( + add_qk, + ["Concat", "Slice", "Slice"], + [1, 0, 0], + ) + attn_mask_nodes_2 = self.model.match_parent_path( + add_qk, + ["Cast", "Concat", "Slice", "Slice"], + [1, 0, 0, 0], + ) + attn_mask_nodes_3 = self.model.match_parent_path( + add_qk, + ["Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 0, 2, 1, 0, 0, 0], + ) + attn_mask_nodes_4 = self.model.match_parent_path( + add_qk, + ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 2, 1, 0, 0, 0], + ) + attn_mask_nodes_5 = self.model.match_parent_path( + add_qk, + ["Expand", "Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 0, 0, 2, 1, 0, 0, 0], + ) + attn_mask_nodes_6 = self.model.match_parent_path( + add_qk, + ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 0, 2, 1, 0, 0, 0], + ) + if attn_mask_nodes_1 is not None: + _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1 + attn_mask = slice_mask_1.output[0] + elif attn_mask_nodes_2 is not None: + _, _, slice_mask_1, slice_mask_2 = attn_mask_nodes_2 + attn_mask = slice_mask_1.output[0] + elif attn_mask_nodes_3 is not None: + # Reshape from (B,1,S,T) to (B,N,S,T) + add_qk_str = self.reshape_add_qk(attn_mask_nodes_3[0].output[0]) + elif attn_mask_nodes_4 is not None: + # Reshape from (B,1,S,T) to (B,N,S,T) + add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0]) + elif attn_mask_nodes_5 is not None: + # The mask has already been reshaped to (B,N,S,T) + add_qk_str = attn_mask_nodes_5[0].output[0] + elif attn_mask_nodes_6 is not None: + # The mask has already been reshaped to (B,N,S,T) + add_qk_str = attn_mask_nodes_6[0].output[0] + else: + logger.debug("fuse_rotary_attention: failed to match attention mask nodes") + return + + # k_nodes_1 is for LLaMA-2 Microsoft + # k_nodes_2 is for LLaMA-2 Hugging Face + # k_nodes_4 is for LLaMA-2 70B Hugging Face + past_k, present_k = "", "" + k_nodes = None + k_nodes_1 = self.model.match_parent_path( + matmul_qk, + ["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"], + [1, 0, 0, 1, 0, 0], + ) + k_nodes_2 = self.model.match_parent_path( + matmul_qk, + ["Transpose", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 0], + ) + k_nodes_3 = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 0, 0, 0], + ) + _, k_nodes_4, _ = self.model.match_parent_paths_all( + matmul_qk, + [ + ( + [ + "Transpose", + "Reshape", + "Expand", + "Unsqueeze", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Equal", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Equal", + "Mul", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Mul", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0], + ), + ], + output_name_to_node=None, + ) + if k_nodes_1 is not None: + reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1 + k_nodes = k_nodes_1 + + concat_k_path = self.model.match_parent_path( + concat_k, + ["Slice", "Unsqueeze"], + [0, 2], + ) + if concat_k_path is None: + logger.debug("fuse_rotary_attention: failed to match past/present concat in k path") + return + + past_k = concat_k_path[0].input[0] + shared_past_seq_len = concat_k_path[-1].input[0] + present_k = concat_k.output[0] + + assert past_seq_len == shared_past_seq_len + elif k_nodes_2 is not None: + _, rotary_k, _, reshape_k, matmul_k = k_nodes_2 + k_nodes = k_nodes_2 + present_k = rotary_k.output[0] + elif k_nodes_3 is not None: + _, concat_k, rotary_k, _, reshape_k, matmul_k = k_nodes_3 + k_nodes = k_nodes_3 + past_k = concat_k.input[0] + present_k = concat_k.output[0] + elif k_nodes_4 is not None and len(k_nodes_4) == 9: + reshape_k, matmul_k = k_nodes_4[0][-2:] + concat_k, rotary_k = k_nodes_4[0][-5:-3] + k_nodes = k_nodes_4 + past_k = concat_k.input[0] + present_k = concat_k.output[0] + else: + logger.debug("fuse_rotary_attention: failed to match k nodes") + return + + # q_nodes_1 is for LLaMA-2 Microsoft + # q_nodes_2 is for LLaMA-2 Hugging Face + q_nodes = None + q_nodes_1 = self.model.match_parent_path( + matmul_qk, + ["Reshape", "Transpose", "RotaryEmbedding", "MatMul"], + [0, 0, 0, 0], + ) + q_nodes_2 = self.model.match_parent_path( + matmul_qk, + ["RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [0, 0, 0, 0], + ) + if q_nodes_1 is not None: + reshape_q_2, _, rotary_q, matmul_q = q_nodes_1 + q_nodes = q_nodes_1 + elif q_nodes_2 is not None: + rotary_q, _, reshape_q, matmul_q = q_nodes_2 + q_nodes = q_nodes_2 + else: + logger.debug("fuse_rotary_attention: failed to match q nodes") + return + + if matmul_q.input[0] != matmul_k.input[0] and matmul_k.input[0] != matmul_v.input[0]: + logger.debug("fuse_rotary_attention: failed to find the same root_input for q, k, v paths") + return + + root_output = "" + if qkv_nodes == qkv_nodes_1: + if not self.check_runtime_shape_paths_for_function( + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + reshape_v_1, + add_qk, + matmul_q.input[0], + ): + logger.debug("fuse_rotary_attention: failed to verify runtime shape paths") + return + root_output = reshape_qkv_2.output[0] + + elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3): + if not self.check_runtime_shape_paths_for_nodes( + reshape_qkv, + reshape_q, + reshape_k, + reshape_v, + matmul_q.input[0], + ): + logger.debug("fuse_rotary_attention: failed to verify runtime shape paths") + return + root_output = reshape_qkv.output[0] + + # Rename inputs of rotary_q/k so it connects with output of matmul_q/k + # Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding + # After: MatMul --> RotaryEmbedding + rotary_q.input[0] = matmul_q.output[0] + rotary_k.input[0] = matmul_k.output[0] + + # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key) + rotary_k.output[0] = rotary_k.name + "_output_0" + + if qkv_nodes == qkv_nodes_3: + qkv_nodes = qkv_nodes[1:] + + new_node = self.create_mha_node( + matmul_q.input[0], + root_output, + rotary_q, + rotary_k, + matmul_v, + attn_mask, + add_qk_str, + past_k, + past_v, + present_k, + present_v, + ) + if new_node is None: + logger.debug("fuse_rotary_attention: failed to create multi-head attention with rotary embeddings") + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend(qkv_nodes[1:]) + + if v_nodes != v_nodes_4: + self.nodes_to_remove.extend(v_nodes[:-1]) + else: + nodes_to_keep = [v_nodes[0][-1]] + for temp_path in v_nodes: + self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep) + + self.nodes_to_remove.extend(qk_nodes) + + if k_nodes == k_nodes_1: + self.nodes_to_remove.extend(k_nodes[:-2]) + elif k_nodes == k_nodes_2: + self.nodes_to_remove.append(k_nodes[0]) + self.nodes_to_remove.append(k_nodes[2]) + self.nodes_to_remove.append(k_nodes[3]) + elif k_nodes == k_nodes_3: + self.nodes_to_remove.append(k_nodes[0]) + self.nodes_to_remove.append(k_nodes[1]) + self.nodes_to_remove.append(k_nodes[3]) + self.nodes_to_remove.append(k_nodes[4]) + elif k_nodes == k_nodes_4: + nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]] + for temp_path in k_nodes: + self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep) + + if q_nodes == q_nodes_1: + self.nodes_to_remove.extend(q_nodes[:-2]) + elif q_nodes == q_nodes_2: + self.nodes_to_remove.append(q_nodes[1]) + self.nodes_to_remove.append(q_nodes[2]) + + self.prune_graph = True + + +class FusionRotaryEmbeddings(Fusion): + def __init__(self, model: OnnxModel): + self.base_name = "RotaryEmbedding" + super().__init__(model, self.base_name, [self.base_name, self.base_name + ".1", "Add"]) + + # The RotaryEmbedding function can have multiple extraneous constant outputs even though the function is supposed to produce only one output. + # This is a byproduct of a potential CSE bug when using `export_modules_as_functions` in the TorchScript exporter. + # To work around this issue, we set the extraneous constant values from the RotaryEmbedding function as initializers in the locations where they are actually used. + def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProto): + # Find extra outputs and Constant nodes attached to those outputs + extra_constants, extra_outputs = [], [] + for fn_node in function.node: + if fn_node.op_type == "Constant" and fn_node.input == [] and fn_node.output[0] in function.output: + extra_constants.append(fn_node) + output_index = list(function.output).index(fn_node.output[0]) + extra_outputs.append(rot_emb_node.output[output_index]) + + # Set extra Constant node outputs as initializers + extra_initializers = [] + for extra_constant in extra_constants: + constant_tensorproto = extra_constant.attribute[0].t + constant_tensorproto.name = self.model.create_node_name("Constant") + self.model.add_initializer(constant_tensorproto) + extra_initializers.append(constant_tensorproto.name) + + # Update references of Constant node outputs to initializer references + for extra_output, extra_initializer in zip(extra_outputs, extra_initializers): + nodes_to_update = list(filter(lambda entry: extra_output in entry.input, self.model.model.graph.node)) + for node_to_update in nodes_to_update: + OnnxModel.replace_node_input(node_to_update, extra_output, extra_initializer) + + return extra_outputs + + def create_rotary_embeddings_from_function(self, node: NodeProto): + rotary_emb_node_name = self.model.create_node_name(self.base_name) + + matmul_path = self.model.match_parent_path( + node, + ["Reshape", "MatMul"], + [0, 0], + ) + if matmul_path is not None: + reshape_node, matmul_node = matmul_path + else: + logger.debug("fuse_rotary_embeddings: failed to match MatMul") + return + + rotary_emb_inputs = [ + matmul_node.output[0], # x is of shape (B,S,D) instead of (B,S,N,H) + node.input[1], # position_ids + ] + + # Convert cos_cache and sin_cache from node attributes to model initializers + cos_cache_node = list(filter(lambda constant: constant.output[0] == node.input[2], self.model.model.graph.node)) + sin_cache_node = list(filter(lambda constant: constant.output[0] == node.input[3], self.model.model.graph.node)) + cos_cache_name, sin_cache_name = "cos_cache", "sin_cache" + + if ( + len(cos_cache_node) == 1 + and len(sin_cache_node) == 1 + and self.model.get_initializer(cos_cache_name) is None + and self.model.get_initializer(sin_cache_name) is None + ): + cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze() + sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze() + + cos_cache_tensor = helper.make_tensor( + name=cos_cache_name, + data_type=TensorProto.FLOAT, + dims=list(cos_cache.shape), + vals=cos_cache.flatten().tolist(), + ) + self.model.add_initializer(cos_cache_tensor, self.this_graph_name) + sin_cache_tensor = helper.make_tensor( + name=sin_cache_name, + data_type=TensorProto.FLOAT, + dims=list(sin_cache.shape), + vals=sin_cache.flatten().tolist(), + ) + self.model.add_initializer(sin_cache_tensor, self.this_graph_name) + + self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]]) + + rotary_emb_inputs.extend([cos_cache_name, sin_cache_name]) + + rotary_emb_outputs = node.output + if len(rotary_emb_outputs) > 1: + # Re-assign extraneous constant outputs in RotaryEmbedding functions as initializers + func = list(filter(lambda fn: fn.name == node.op_type, self.model.model.functions)) + assert len(func) == 1 + extra_outputs = self.reassign_extra_outputs(node, func[0]) + rotary_emb_outputs = list(filter(lambda output_name: output_name not in extra_outputs, rotary_emb_outputs)) + assert len(rotary_emb_outputs) == 1 + + rotary_emb_node = helper.make_node( + self.base_name, + inputs=rotary_emb_inputs, + outputs=rotary_emb_outputs, + name=rotary_emb_node_name, + interleaved=1, + ) + rotary_emb_node.domain = "com.microsoft" + + self.nodes_to_remove.append(reshape_node) + + return rotary_emb_node + + def create_rotary_embeddings_from_nodes( + self, + root_input: str, + position_ids: str, + cos_slice: str, + sin_slice: str, + output: str, + ): + rotary_emb_node_name = self.model.create_node_name(self.base_name) + + # Convert cos_cache and sin_cache from node attributes to model initializers + cos_cache_node = list(filter(lambda constant: constant.output[0] == cos_slice, self.model.model.graph.node)) + sin_cache_node = list(filter(lambda constant: constant.output[0] == sin_slice, self.model.model.graph.node)) + cos_cache_name, sin_cache_name = "cos_cache", "sin_cache" + + if ( + len(cos_cache_node) == 1 + and len(sin_cache_node) == 1 + and self.model.get_initializer(cos_cache_name) is None + and self.model.get_initializer(sin_cache_name) is None + ): + cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze() + sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze() + + # Reshape cos/sin cache from (M, H) to (M, H/2) + head_size = cos_cache.shape[1] + cos_cache = cos_cache[:, : (head_size // 2)] + sin_cache = sin_cache[:, : (head_size // 2)] + + cos_cache_tensor = helper.make_tensor( + name=cos_cache_name, + data_type=TensorProto.FLOAT, + dims=list(cos_cache.shape), + vals=cos_cache.flatten().tolist(), + ) + self.model.add_initializer(cos_cache_tensor, self.this_graph_name) + sin_cache_tensor = helper.make_tensor( + name=sin_cache_name, + data_type=TensorProto.FLOAT, + dims=list(sin_cache.shape), + vals=sin_cache.flatten().tolist(), + ) + self.model.add_initializer(sin_cache_tensor, self.this_graph_name) + + self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]]) + + rotary_emb_node = helper.make_node( + self.base_name, + inputs=[root_input, position_ids, cos_cache_name, sin_cache_name], + outputs=[output], + name=rotary_emb_node_name, + interleaved=0, + ) + rotary_emb_node.domain = "com.microsoft" + return rotary_emb_node + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + # Node is either RotaryEmbedding function or Add + if self.base_name not in node.op_type and node.op_type != "Add": + return + + # Check if node is "RotaryEmbedding nn.Module" exported as a function + # (e.g. export_modules_as_functions={RotaryEmbedding} in torch.onnx.export) + rotary_emb_node = None + if node.op_type != "Add": + # Verify that function has the correct inputs + if len(node.input) not in {4, 5} or node.input[1] not in { + "pos", + "pos_id", + "position_id", + "pos_ids", + "position_ids", + }: + logger.debug("fuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding function") + return + + rotary_emb_node = self.create_rotary_embeddings_from_function(node) + if rotary_emb_node is None: + logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node") + return + + # Remove RotaryEmbedding function + self.nodes_to_remove.append(node) + + # Remove RotaryEmbedding function's shape inference stored in value_info + # The new shape will be calculated during symbolic shape inference + old_shape_infer = list( + filter(lambda node: node.name == rotary_emb_node.output[0], self.model.model.graph.value_info) + ) + assert len(old_shape_infer) == 1 + self.model.model.graph.value_info.remove(old_shape_infer[0]) + + else: + # Rotary embeddings are defined using the below functions: + # + # def rotate_half(x): + # """Rotates half the hidden dims of the input.""" + # x1 = x[..., : x.shape[-1] // 2] + # x2 = x[..., x.shape[-1] // 2 :] + # return torch.cat((-x2, x1), dim=-1) + # + # def apply_rope(x, cos, sin, position_ids): + # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + # cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + # sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + # x_embed = (x * cos) + (rotate_half(x) * sin) + # return x_embed + + # Check paths for rotate_half(x) + rotate_half_x2_path_1 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Neg", "Slice", "Transpose"], + [1, 0, 0, 0, 0], + ) + rotate_half_x2_path_2 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"], + [1, 0, 0, 0, 1, 0, 0, 0, 0], + ) + if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None: + logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half") + return + + rotate_half_x1_path_1 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Slice", "Transpose"], + [1, 0, 1, 0], + ) + rotate_half_x1_path_2 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"], + [1, 0, 1, 2, 0, 0, 0, 0], + ) + if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None: + logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half") + return + + if ( + rotate_half_x1_path_1[-1].name != rotate_half_x1_path_2[-1].name + or rotate_half_x2_path_1[-1].name != rotate_half_x2_path_2[-1].name + or rotate_half_x1_path_1[-1].name != rotate_half_x2_path_1[-1].name + or rotate_half_x1_path_2[-1].name != rotate_half_x2_path_2[-1].name + ): + logger.debug("fuse_rotary_embeddings: failed to match common input in rotate_half") + return + + # Check path for x + x_path = self.model.match_parent_path( + node, + ["Mul", "Transpose"], + [0, 0], + ) + if x_path is None: + logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half") + return + + # Check path for sin + sin_path, sin_cache, position_ids = None, "", "" + sin_path_1 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"], + [1, 1, 0, 0, 0, 0, 2, 0, 0], + ) + sin_path_2 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"], + [1, 1, 0, 0, 0, 0, 2, 0], + ) + sin_path_3 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"], + [1, 1, 0, 0, 2, 0, 0], + ) + sin_path_4 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"], + [1, 1, 0, 0, 2, 0], + ) + if sin_path_1 is not None: + sin_path = sin_path_1 + sin_cache = sin_path[-4].input[0] + elif sin_path_2 is not None: + sin_path = sin_path_2 + sin_cache = sin_path[-3].input[0] + elif sin_path_3 is not None: + sin_path = sin_path_3 + sin_cache = sin_path[-4].input[0] + position_ids = sin_path[2].input[1] + elif sin_path_4 is not None: + sin_path = sin_path_4 + sin_cache = sin_path[-3].input[0] + position_ids = sin_path[2].input[1] + else: + logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope") + return + + # Check path for cos + cos_path, cos_cache = None, "" + cos_path_1 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"], + [0, 1, 0, 0, 0, 0, 2, 0, 0], + ) + cos_path_2 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"], + [0, 1, 0, 0, 0, 0, 2, 0], + ) + cos_path_3 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"], + [0, 1, 0, 0, 2, 0, 0], + ) + cos_path_4 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"], + [0, 1, 0, 0, 2, 0], + ) + if cos_path_1 is not None: + cos_path = cos_path_1 + cos_cache = cos_path[-4].input[0] + elif cos_path_2 is not None: + cos_path = cos_path_2 + cos_cache = cos_path[-3].input[0] + elif cos_path_3 is not None: + cos_path = cos_path_3 + cos_cache = cos_path[-4].input[0] + position_ids = cos_path[2].input[1] + elif cos_path_4 is not None: + cos_path = cos_path_4 + cos_cache = cos_path[-3].input[0] + position_ids = cos_path[2].input[1] + else: + logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope") + return + + # Check path for position ids + if position_ids == "": + position_ids_from_sin_path = self.model.match_parent_path( + sin_path[2], + ["Reshape"], + [1], + ) + position_ids_from_cos_path = self.model.match_parent_path( + cos_path[2], + ["Reshape"], + [1], + ) + if ( + position_ids_from_sin_path is None + or position_ids_from_cos_path is None + or position_ids_from_sin_path[0].name != position_ids_from_cos_path[0].name + ): + logger.debug("fuse_rotary_embeddings: failed to match position ids path in apply_rope") + return + position_ids = position_ids_from_cos_path[0].input[0] + else: + position_ids_from_sin_path = [] + position_ids_from_cos_path = [] + + past_seq_len_path, curr_seq_len_path = None, None + if (sin_path == sin_path_1 and cos_path == cos_path_1) or ( + sin_path == sin_path_3 and cos_path == cos_path_3 + ): + if sin_path[-2].name != cos_path[-2].name or sin_path[-1].name != cos_path[-1].name: + logger.debug( + "fuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cache" + ) + return + elif (sin_path == sin_path_2 and cos_path == cos_path_2) or ( + sin_path == sin_path_4 and cos_path == cos_path_4 + ): + if sin_path[-1].name != cos_path[-1].name: + logger.debug("fuse_rotary_embeddings: failed to match common Add node in sin cache and cos cache") + return + # Match past sequence length path: past_key --> Shape --> Gather --> Add + past_seq_len_path = self.model.match_parent_path( + sin_path[-1], + ["Gather", "Shape"], + [1, 0], + ) + # Match current sequence length path: transpose_k --> Shape --> Gather --> Add + curr_seq_len_path = self.model.match_parent_path( + sin_path[-1], + ["Gather", "Shape", "Transpose"], + [0, 0, 0], + ) + if ( + past_seq_len_path is None + or curr_seq_len_path is None + or self.model.find_graph_input(past_seq_len_path[-1].input[0]) is None + or curr_seq_len_path[-1].op_type != "Transpose" + ): + logger.debug("fuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len paths") + return + else: + logger.debug("fuse_rotary_embeddings: failed to match common cache paths") + + rotary_emb_node = self.create_rotary_embeddings_from_nodes( + rotate_half_x1_path_1[-1].output[0], + position_ids, + cos_cache, + sin_cache, + node.output[0], + ) + if rotary_emb_node is None: + logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node") + return + + # Remove rotary embedding nodes + self.add_nodes_to_remove([node]) + self.add_nodes_to_remove(rotate_half_x1_path_1[:-1]) + self.add_nodes_to_remove(rotate_half_x1_path_2[:-1]) + self.add_nodes_to_remove(rotate_half_x2_path_1[:-1]) + self.add_nodes_to_remove(rotate_half_x2_path_2[:-1]) + self.add_nodes_to_remove(x_path[:-1]) + self.add_nodes_to_remove(sin_path) + self.add_nodes_to_remove(cos_path) + self.add_nodes_to_remove(position_ids_from_sin_path[:-1]) + self.add_nodes_to_remove(position_ids_from_cos_path[:-1]) + + if past_seq_len_path is not None and len(self.model.get_children(past_seq_len_path[0])) == 1: + # In merged HF model, output of Gather in past_seq_len_path is used twice + # for past_key_values.0.key and once for other past_key_values + self.add_nodes_to_remove(past_seq_len_path) + if curr_seq_len_path is not None: + self.add_nodes_to_remove(curr_seq_len_path[:-1]) + + self.increase_counter(self.base_name) + self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name + self.nodes_to_add.append(rotary_emb_node) + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_shape.py b/onnxruntime/python/tools/transformers/fusion_shape.py index 11d6b7a8d3..bc32d78eda 100644 --- a/onnxruntime/python/tools/transformers/fusion_shape.py +++ b/onnxruntime/python/tools/transformers/fusion_shape.py @@ -48,22 +48,22 @@ def fuse( input_name_to_nodes: Dict[str, List[NodeProto]], output_name_to_node: Dict[str, NodeProto], ): - """ - Smplify subgraph like - - (2d_input) - / \ - Shape shape - / \ - Gather(indices=0) Gather(indices=1) - | | - Unsqueeze(axes=0) Unsqueeze(axes=0) - \\ / - Concat - | - - into (2d_input) --> Shape --> - """ + # + # Simplify subgraph like + # + # (2d_input) + # / \ + # Shape shape + # / \ + # Gather(indices=0) Gather(indices=1) + # | | + # Unsqueeze(axes=0) Unsqueeze(axes=0) + # \ / + # Concat + # | + # + # into (2d_input) --> Shape --> + # opset_version = self.model.get_opset_version() inputs = len(concat_node.input) diff --git a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py new file mode 100644 index 0000000000..6f35fa5617 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py @@ -0,0 +1,141 @@ +import logging +from typing import Dict + +from fusion_base import Fusion +from fusion_skiplayernorm import FusionSkipLayerNormalization +from onnx import helper +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionSimplifiedLayerNormalization(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "SimplifiedLayerNormalization", "Mul") + + def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + if node.op_type != "Mul": + return + + sim_ln_nodes = None + # SimplifiedLayerNorm calculation (notation from https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary): + # DD = Pow(D, 2) + # Var = ReduceMean(DD) + # VarEps = Add(Var, epsilon) + # StdDev = Sqrt(VarEps) + # InvStdDev = Div(1, StdDev) + # Normalized = Mul(D, InvStdDev) + # NormalizedScaled = Mul(Normalized, Scale) + + # SimplifiedLayerNorm + # +-------------------------------------------------------+ + # | | + # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul + # | + # node + sim_ln_nodes_1 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], + [1, 1, 1, 0, 0, 0, 0], + ) + # SimplifiedLayerNorm + # +-------------------------------------------------------+ + # | | + # Gather --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul + # | + # node + sim_ln_nodes_2 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"], + [1, 1, 1, 0, 0, 0, 0], + ) + + # For LLaMA from Microsoft custom export: + # sim_ln_nodes_3 uses a different start parent index than sim_ln_nodes_1 + # + # SimplifiedLayerNorm + # +-------------------------------------------------------+ + # | | + # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul + # | + # node + sim_ln_nodes_3 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], + [0, 1, 1, 0, 0, 0, 0], + ) + + # sim_ln_nodes_4 starts with a graph input instead of an Add node like sim_ln_nodes_3 + # + # SimplifiedLayerNorm + # +-----------------------------------------------+ + # | | + # graph_input --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul + # | + # node + sim_ln_nodes_4 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow"], + [0, 1, 1, 0, 0, 0], + ) + + add_node, pow_node = None, None + if sim_ln_nodes_1 is not None: + sim_ln_nodes = sim_ln_nodes_1 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-2] + elif sim_ln_nodes_2 is not None: + sim_ln_nodes = sim_ln_nodes_2 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-2] + elif sim_ln_nodes_3 is not None: + sim_ln_nodes = sim_ln_nodes_3 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-2] + elif sim_ln_nodes_4 is not None: + sim_ln_nodes = sim_ln_nodes_4 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-1] + # Verify that parent input to Pow node is graph_input + if pow_node.input[0] not in self.model.get_graphs_input_names(): + return + else: + return + + layernorm_weight_index = 1 if sim_ln_nodes in (sim_ln_nodes_3, sim_ln_nodes_4) else 0 + starts_with_graph_input = sim_ln_nodes == sim_ln_nodes_4 + + if self.model.find_constant_input(pow_node, 2.0) != 1: + return + + root_input = pow_node.input[0] + if root_input != sim_ln_nodes[0].input[0]: + return + + i, add_weight = self.model.get_constant_input(add_node) + if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: + logger.warning(f"epsilon value is not expected: {add_weight}") + return + + self.nodes_to_remove.extend(sim_ln_nodes[:-1] if not starts_with_graph_input else sim_ln_nodes) + self.nodes_to_remove.append(node) + + normalize_node = helper.make_node( + "SimplifiedLayerNormalization", + inputs=[root_input, node.input[layernorm_weight_index]], + outputs=[node.output[0]], + name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"), + ) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) + normalize_node.attribute.extend([helper.make_attribute("axis", -1)]) + normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)]) + self.nodes_to_add.append(normalize_node) + self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name + + +class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization): + def __init__(self, model: OnnxModel): + super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization") + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + super().fuse(node, input_name_to_nodes, output_name_to_node) diff --git a/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py b/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py new file mode 100644 index 0000000000..df80acbd97 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py @@ -0,0 +1,255 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from logging import getLogger +from typing import List + +from fusion_base import Fusion +from fusion_utils import NumpyHelper +from onnx import helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class FusionSkipGroupNorm(Fusion): + """ + Fuse Add + GroupNorm into one node: SkipGroupNorm. + """ + + def __init__(self, model: OnnxModel): + super().__init__(model, "SkipGroupNorm", "GroupNorm") + # Update shape inference is needed since other fusions might add new edge which does not have shape info yet. + self.shape_infer_helper = self.model.infer_runtime_shape(update=True) + + if self.shape_infer_helper is None: + logger.warning("SkipGroupNorm fusion will be skipped since symbolic shape inference disabled or failed.") + + def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): + """Append a Transpose node after an input""" + node_name = self.model.create_node_name("Transpose") + if output_name is None: + output_name = node_name + "_out" + "-" + input_name + transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name) + transpose_node.attribute.extend([helper.make_attribute("perm", perm)]) + return transpose_node + + def get_skip_index(self, add, is_channel_last: bool): + """Add has two inputs. This classifies which input is skip based on shape info (skip allows broadcast).""" + skip = -1 + broadcast = False + + assert self.shape_infer_helper is not None + shape_a = self.shape_infer_helper.get_edge_shape(add.input[0]) + shape_b = self.shape_infer_helper.get_edge_shape(add.input[1]) + assert shape_a is not None and shape_b is not None + + if len(shape_a) == 4 and len(shape_b) == 4: + if shape_a == shape_b: + skip = 1 + else: + c = 3 if is_channel_last else 1 + h = 1 if is_channel_last else 2 + w = 2 if is_channel_last else 3 + if shape_a[0] == shape_b[0] and shape_a[c] == shape_b[c]: + if shape_b[h] == 1 and shape_b[w] == 1: + skip = 1 + broadcast = True + elif shape_a[h] == 1 and shape_a[w] == 1: + skip = 0 + broadcast = True + + if skip < 0: + logger.debug( + "skip SkipGroupNorm fusion since shape of Add inputs (%s, %s) are not expected", + add.input[0], + add.input[1], + ) + return skip, broadcast + + def has_multiple_consumers(self, output_name, input_name_to_nodes): + """Whether an output has multiple consumers (like graph output or more than one children nodes)""" + return self.model.find_graph_output(output_name) is not None or ( + output_name in input_name_to_nodes and len(input_name_to_nodes[output_name]) > 1 + ) + + def remove_if_safe(self, node, input_name_to_nodes): + """Remove a node if it is safe (only one children, and not graph output)""" + if not self.has_multiple_consumers(node.output[0], input_name_to_nodes): + self.nodes_to_remove.extend([node]) + + def is_bias_1d(self, bias_name: str): + """Whether bias is an initializer of one dimension""" + initializer = self.model.get_initializer(bias_name) + if initializer is None: + return False + + bias_weight = NumpyHelper.to_array(initializer) + if bias_weight is None: + logger.debug("Bias weight not found") + return False + + if len(bias_weight.shape) != 1: + logger.debug("Bias weight is not 1D") + return False + return True + + def match_bias_path(self, node, input_name_to_nodes, output_name_to_node): + """ + Match the bias graph pattern from an Transpose node after Reshape node like in below example. + It checks whether the bias is 1D initializer. If so, remove Add and redirect MatMul output to Reshape. + """ + # Before Fusion: + # MatMul (bias) + # \ / (shape) + # Add / + # \ / + # (a) Reshape + # \ | + # Transpose([0, 3, 1, 2]) Transpose([0, 3, 1, 2]) --- the start node, this func only handles the above nodes. + # \ / + # Add + # / \ + # (c) Transpose([0,2,3,1]) + # | + # GroupNorm + # | + # (d) + # + # After Fusion (the nodes below Reshape is handled in the fuse function): + # MatMul (shape) + # \ / + # (a) Reshape + # \ / + # SkipGroupNorm + # / \ + # (d) Transpose([0, 3, 1, 2]) + # \ + # (c) + + add_input_index = [] + bias_nodes = self.model.match_parent_path( + node, ["Reshape", "Add", "MatMul"], [0, 0, None], output_name_to_node, add_input_index + ) + if bias_nodes is None: + return None + + (reshape, add_bias, matmul) = bias_nodes + bias = bias_nodes[1].input[1 - add_input_index[0]] + if not self.is_bias_1d(bias): + return None + + reshape.input[0] = matmul.output[0] + self.remove_if_safe(add_bias, input_name_to_nodes) + + return bias + + def match_transpose_from_nhwc(self, output_name, input_name_to_nodes, output_name_to_node): + """Match whether an output is from a Transpose(perm=[0,3,1,2]) node.""" + parent = output_name_to_node[output_name] if output_name in output_name_to_node else None + if parent is not None and parent.op_type == "Transpose": + permutation = OnnxModel.get_node_attribute(parent, "perm") + if permutation == [0, 3, 1, 2]: + self.remove_if_safe(parent, input_name_to_nodes) + return parent + return None + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + # This fusion requires shape information, so skip it if shape is not available. + if self.shape_infer_helper is None: + return + + # Before Fusion: + # (a) (b) + # \ / + # Add + # /\ + # (c) Transpose([0,2,3,1]) + # \ + # GroupNorm + # | + # (d) + # + # After Fusion: + # (a) (b) + # \ / + # Transpose([0,2,3,1]) Transpose([0,2,3,1]) + # \ / + # SkipGroupNorm + # / \ + # / Transpose([0, 3, 1, 2]) + # / \ + # (d) (c) + nodes = self.model.match_parent_path(node, ["Transpose", "Add"], [0, 0], output_name_to_node) + if nodes is None: + return + + (transpose, add) = nodes + if transpose in self.nodes_to_remove or add in self.nodes_to_remove: + return + + if self.has_multiple_consumers(transpose.output[0], input_name_to_nodes): + return + + permutation = OnnxModel.get_node_attribute(transpose, "perm") + if permutation != [0, 2, 3, 1]: + return + + inputs = [] + bias = None + for i in range(2): + matched_transpose = self.match_transpose_from_nhwc(add.input[i], input_name_to_nodes, output_name_to_node) + if matched_transpose: + # When there is an Transpose node before Add (see examples in match_bias_path), we do not need to + # insert another Transpose node. The existing Transpose node will be removed in prune_graph if it + # has only one consumer. + inputs.append(matched_transpose.input[0]) + # See whether it match bias pattern. + if bias is None: + bias = self.match_bias_path(matched_transpose, input_name_to_nodes, output_name_to_node) + else: + # Otherwise, insert a Transpose node before Add. + new_transpose = self.create_transpose_node(add.input[i], [0, 2, 3, 1]) + self.model.add_node(new_transpose, self.this_graph_name) + inputs.append(new_transpose.output[0]) + + skip, broadcast = self.get_skip_index(add, is_channel_last=False) + if skip < 0: + return + + inputs = [inputs[1 - skip], node.input[1], node.input[2], inputs[skip]] + if bias: + inputs = [*inputs, bias] + + outputs = node.output + + new_node_name = self.model.create_node_name(self.fused_op_type, name_prefix="SkipGroupNorm") + if self.has_multiple_consumers(add.output[0], input_name_to_nodes): + add_out_name = new_node_name + "_add_out" + outputs.append(add_out_name) + + # Insert a Transpose node after add output. + add_out_transpose = self.create_transpose_node(add_out_name, [0, 3, 1, 2], add.output[0]) + self.model.add_node(add_out_transpose, self.this_graph_name) + + skip_group_norm = helper.make_node( + self.fused_op_type, + inputs=inputs, + outputs=outputs, + name=new_node_name, + ) + skip_group_norm.domain = "com.microsoft" + + self.increase_counter( + f"SkipGroupNorm(add_out={int(len(outputs) > 1)} bias={int(bias is not None)} broadcast={int(broadcast)})" + ) + + # Pass attributes from GroupNorm node to SkipGroupNorm + for att in node.attribute: + skip_group_norm.attribute.extend([att]) + + self.nodes_to_remove.extend([add, transpose, node]) + self.nodes_to_add.append(skip_group_norm) + self.node_name_to_graph_name[skip_group_norm.name] = self.this_graph_name + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py index 4b771c5bee..1ec5edf686 100644 --- a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py @@ -38,17 +38,17 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): # In some models there is input_ids->gather->add->LayerNorm and one of input of the # add node is initializer with fixed shape which should not be fused into SkipLayerNorm - if add is None: + if add is None or add.op_type != "Add": + return + + # The number of inputs of add should be 2 + if len(add.input) != 2: return for add_input in add.input: if self.model.get_initializer(add_input) is not None: return - # The number of input node of add should be 2 - if len(self.model.get_parents(add)) != 2: - return - # To avoid an Add node have two children of LayerNormalization, we shall only fuse one SkipLayerNormalization if add in self.nodes_to_remove: return @@ -57,6 +57,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): simplified = node.op_type == "SimplifiedLayerNormalization" if self.shape_infer_helper is not None: + # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size) if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): logger.debug( "skip SkipLayerNormalization fusion since shape of inputs (%s, %s) are not same", @@ -73,15 +74,14 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None: return - residual_add_has_multiple_consumers = False - add_children = self.model.get_children(add, input_name_to_nodes) - # This means that the residual Add before the LayerNormalization produces an output - # that is consumed by some other nodes other than the LayerNormalization itself + # that is consumed by some other nodes or graph output other than the LayerNormalization itself # We can still go ahead with the SkipLayerNormalization fusion but we need to # preserve the output of Add and that needs to be produced by SkipLayerNormalization. - if len(add_children) != 1: - residual_add_has_multiple_consumers = True + add_has_graph_output = self.model.find_graph_output(add.output[0]) is not None + residual_add_has_multiple_consumers = ( + add_has_graph_output or len(self.model.get_children(add, input_name_to_nodes)) > 1 + ) outputs_to_keep = node.output @@ -94,11 +94,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if residual_add_has_multiple_consumers: outputs.extend(["", "", add.output[0]]) - if ( - add is not None - and add.op_type == "Add" - and self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node) - ): + if self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.extend([add, node]) inputs = ( @@ -136,32 +132,33 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): return return_indice = [] - nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], None, return_indice) - if nodes is None: + nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], output_name_to_node, return_indice) + if nodes is not None: + (add, _matmul) = nodes + else: # In case of fp16, we could have a Cast between the MatMul and the bias Add + return_indice = [] nodes = self.model.match_parent_path( - node, ["Add", "Cast", "MatMul"], [None, None, None], None, return_indice + node, ["Add", "Cast", "MatMul"], [None, None, None], output_name_to_node, return_indice ) - if nodes is None: + if nodes is not None: + (add, _cast, _matmul) = nodes + else: return assert len(return_indice) == 2 or len(return_indice) == 3 add_input_index = return_indice[0] if add_input_index >= 2: return - - (add, matmul) = nodes + sln_input = add.input[return_indice[1]] + bias_input = add.input[1 - return_indice[1]] + skip_input = node.input[1 - add_input_index] # bias should be one dimension - bias_index = -1 - bias_weight = None - for i, input in enumerate(add.input): - initializer = self.model.get_initializer(input) - if initializer is None: - continue - bias_index = i - bias_weight = NumpyHelper.to_array(initializer) - break + initializer = self.model.get_initializer(bias_input) + if initializer is None: + return + bias_weight = NumpyHelper.to_array(initializer) if bias_weight is None: logger.debug("Bias weight not found") return @@ -176,11 +173,11 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.extend(subgraph_nodes) inputs = [ - node.input[1 - add_input_index], - matmul.output[0], + sln_input, + skip_input, node.input[2], node.input[3], - add.input[bias_index], + bias_input, ] new_node = helper.make_node( "SkipLayerNormalization", diff --git a/onnxruntime/python/tools/transformers/fusion_transpose.py b/onnxruntime/python/tools/transformers/fusion_transpose.py index 6602d16830..2762d95dd7 100644 --- a/onnxruntime/python/tools/transformers/fusion_transpose.py +++ b/onnxruntime/python/tools/transformers/fusion_transpose.py @@ -139,23 +139,23 @@ def fuse( # Here we use hard-coded name so that it could be shared for the whole model. axes_1 = "ort_const_unsqueeze_axes_1" if self.model.get_initializer(axes_1) is None: - axes_1_tensor = helper.make_tensor( + self.add_initializer( name=axes_1, data_type=TensorProto.INT64, dims=[1], vals=[1], + raw=False, ) - self.model.add_initializer(axes_1_tensor, self.this_graph_name) axes_2 = "ort_const_unsqueeze_axes_2" if self.model.get_initializer(axes_2) is None: - axes_2_tensor = helper.make_tensor( + self.add_initializer( name=axes_2, data_type=TensorProto.INT64, dims=[1], vals=[2], + raw=False, ) - self.model.add_initializer(axes_2_tensor, self.this_graph_name) unsqueeze_3.input[1] = "ort_const_unsqueeze_axes_2" unsqueeze_2.input[1] = "ort_const_unsqueeze_axes_1" diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 0715395268..50703b9c17 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -1,5 +1,6 @@ import logging -from typing import Dict, List +from collections import OrderedDict +from typing import Any, Dict, List, Tuple, Union import numpy import torch @@ -205,3 +206,113 @@ def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shape else: ort_outputs.append(copy_tensor) return ort_outputs + + +class CudaSession: + """Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider""" + + def __init__(self, ort_session: InferenceSession, device: torch.device, enable_cuda_graph=False): + self.ort_session = ort_session + self.input_names = [input.name for input in self.ort_session.get_inputs()] + self.output_names = [output.name for output in self.ort_session.get_outputs()] + self.io_name_to_numpy_type = TypeHelper.get_io_numpy_type_map(self.ort_session) + self.io_binding = self.ort_session.io_binding() + self.enable_cuda_graph = enable_cuda_graph + + self.input_tensors = OrderedDict() + self.output_tensors = OrderedDict() + self.device = device + + def __del__(self): + del self.input_tensors + del self.output_tensors + del self.io_binding + del self.ort_session + + def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]): + """Allocate tensors for I/O Binding""" + if self.enable_cuda_graph: + for name, shape in shape_dict.items(): + if name in self.input_names: + # Reuse allocated buffer when the shape is same + if name in self.input_tensors: + if tuple(self.input_tensors[name].shape) == tuple(shape): + continue + raise RuntimeError("Expect static input shape for cuda graph") + + numpy_dtype = self.io_name_to_numpy_type[name] + tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( + device=self.device + ) + self.input_tensors[name] = tensor + + self.io_binding.bind_input( + name, + tensor.device.type, + tensor.device.index, + numpy_dtype, + list(tensor.size()), + tensor.data_ptr(), + ) + + for name, shape in shape_dict.items(): + if name in self.output_names: + # Reuse allocated buffer when the shape is same + if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape): + continue + + numpy_dtype = self.io_name_to_numpy_type[name] + tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( + device=self.device + ) + self.output_tensors[name] = tensor + + self.io_binding.bind_output( + name, + tensor.device.type, + tensor.device.index, + numpy_dtype, + list(tensor.size()), + tensor.data_ptr(), + ) + + def infer(self, feed_dict: Dict[str, torch.Tensor]): + """Bind input tensors and run inference""" + for name, tensor in feed_dict.items(): + assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous() + if name in self.input_names: + if self.enable_cuda_graph: + assert self.input_tensors[name].nelement() == tensor.nelement() + assert self.input_tensors[name].dtype == tensor.dtype + assert tensor.device.type == "cuda" + # Please install cuda-python package with a version corresponding to CUDA in your machine. + from cuda import cudart + + # Update input tensor inplace since cuda graph requires input and output has fixed memory address. + cudart.cudaMemcpy( + self.input_tensors[name].data_ptr(), + tensor.data_ptr(), + tensor.element_size() * tensor.nelement(), + cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, + ) + else: + self.io_binding.bind_input( + name, + tensor.device.type, + tensor.device.index, + TypeHelper.torch_type_to_numpy_type(tensor.dtype), + [1] if len(tensor.shape) == 0 else list(tensor.shape), + tensor.data_ptr(), + ) + + self.ort_session.run_with_iobinding(self.io_binding) + + return self.output_tensors + + @staticmethod + def get_cuda_provider_options(device_id: int, enable_cuda_graph: bool) -> Dict[str, Any]: + return { + "device_id": device_id, + "arena_extend_strategy": "kSameAsRequested", + "enable_cuda_graph": enable_cuda_graph, + } diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py new file mode 100644 index 0000000000..3b344d6dc9 --- /dev/null +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -0,0 +1,385 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" +Export LLM to onnx +""" +import argparse +import inspect +import math +import os +import tempfile +from pathlib import Path +from typing import Optional + +import onnx +import torch +import transformers +from torch import nn + + +def disable_huggingface_init(): + """do not init model twice as it slow initialization""" + + torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.normal_ = lambda x, *args, **kwargs: x + torch.nn.init.constant_ = lambda x, *args, **kwargs: x + torch.nn.init.xavier_uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.xavier_normal_ = lambda x, *args, **kwargs: x + torch.nn.init.kaiming_normal_ = lambda x, *args, **kwargs: x + torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x + + +def get_model_parameter_size(model: nn.Module): + """to calculate how much memory this model needs""" + param_size = 0 + param_sum = 0 + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + param_sum += param.nelement() + buffer_size = 0 + buffer_sum = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + buffer_sum += buffer.nelement() + all_size = (param_size + buffer_size) / 1024 / 1024 + return all_size + + +def initialize_model_and_sample_inputs(hf_model: str, cache_dir: Optional[str], tokenizer=None): + """ + get the pretrained torch model from hugginface, + and sample model-inputs + """ + + disable_huggingface_init() + + model = transformers.AutoModelForCausalLM.from_pretrained( # type: ignore + hf_model, torch_dtype=torch.float16, cache_dir=cache_dir, trust_remote_code=True + ) + if tokenizer is None: + tokenizer = hf_model + tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) # type: ignore + + sample_inputs = tuple(tokenizer("Hello, my dog is cute", return_tensors="pt").values()) + return model, sample_inputs + + +def auto_pipeline_parallel(model: nn.Module, gpulist: list, sample_inputs: tuple): + """Make the model executable across multiple GPUs.""" + + def input_gpu_device_hook(mod, inputs, kwargs): + modifyed_inputs = [] + first_dev = None + for layer_input in inputs: + if type(layer_input) is not torch.Tensor: + modifyed_inputs.append(layer_input) + elif hasattr(mod, "weight"): + modifyed_inputs.append(layer_input.to(mod.weight.device)) + elif hasattr(mod, "parameters"): + device = next(mod.parameters(), layer_input).device + modifyed_inputs.append(layer_input.to(device)) + elif hasattr(next(mod.children(), None), "weight"): + modifyed_inputs.append(layer_input.to(next(mod.children()).weight.device)) + elif first_dev is not None and layer_input.device != first_dev: + modifyed_inputs.append(layer_input.to(first_dev)) + else: + modifyed_inputs.append(layer_input) + if first_dev is None: + first_dev = modifyed_inputs[0].device + for key, value in kwargs.items(): + if type(value) is torch.Tensor: + kwargs[key] = value.to(first_dev) + + return (tuple(modifyed_inputs), kwargs) + + def move_layer_to_device_rurc(mod, dev): + mod.to(dev) + for layer in mod.named_children(): + move_layer_to_device_rurc(layer[1], dev) + + model = model.half() + all_hooks = [] + all_hooks.append(model.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + pre_fix = next(iter(model.named_children()))[0] + for top_name, top_module in model.named_children(): + for name, module in top_module.named_children(): + all_hooks.append(module.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + if type(module) in [torch.nn.ModuleList]: + num_layers_on_each_gpu = math.floor(len(module) / len(gpulist)) + for idx, attn_layer in enumerate(module): + all_hooks.append(attn_layer.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + + to_dev = gpulist[min(idx // num_layers_on_each_gpu, len(gpulist))] + attn_layer.to(to_dev) + move_layer_to_device_rurc(attn_layer, to_dev) + print(f"move {pre_fix}.{name}.{idx} to {to_dev}") + else: + module.to(gpulist[0]) + print(f"move {pre_fix}.{name} to {gpulist[0]}") + if len(list(top_module.named_children())) == 0: + top_module.to(gpulist[0]) + print(f"move {top_name} to {gpulist[0]}") + + with torch.no_grad(): + model(sample_inputs[0], attention_mask=sample_inputs[1]) + return model + + +def retrieve_onnx_inputs(model: nn.Module, sample_inputs: tuple, with_past: bool): + """ + auto retrieve onnx inputs from torch model as we can't enumlate all possibilities + for all models + """ + user_inputs = [] + + def hook_for_inputs(_, inputs, kwargs): + user_inputs.append((inputs, kwargs)) + return user_inputs[0] + + hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True) + + forward_params = inspect.signature(model.forward).parameters + input_keys = list(forward_params.keys()) + default_values = [forward_params.get(key).default for key in input_keys] + out = model(sample_inputs[0], attention_mask=sample_inputs[1]) + hook_handle.remove() + user_inputs = user_inputs[0] + onnx_inputs = default_values + for idx, _val in enumerate(user_inputs[0]): + onnx_inputs[idx] = user_inputs[0][idx] + for key, value in user_inputs[1].items(): + idx = input_keys.index(key) + onnx_inputs[idx] = value + for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)): + if type(value) is torch.Tensor: + value.to(model.device) + # Didn't touch past_key_value now, please change it if you want + if "use_cache" in key: + onnx_inputs[idx] = with_past + + return input_keys, onnx_inputs, out.past_key_values + + +def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module: + """ + According to the model size, we will upload it to + CPU if has no GPU or enough GPU memory, + Single GPU if has only one GPU in local or model size is enough to fit one GPU + Multiple GPU if there is more than one gpu in local and model is too large + """ + total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 + + print(f"Model_Size = {get_model_parameter_size(model)/1024} GB") + print(f"total_mem_per_cpu = {total_mem_per_cpu/1024} GB") + if get_model_parameter_size(model) > total_mem_per_cpu * 0.45: + device_collection = [torch.device(i) for i in range(torch.cuda.device_count())] + if len(device_collection) > 1: + print( + f"{len(device_collection)} GPUs are used to export onnx, \ + Please set CUDA_VISIBLE_DEVICES to use specific GPU group" + ) + model = auto_pipeline_parallel(model, device_collection, sample_inputs_tp) + else: + print("!!!! convert model to float and export onnx using CPU") + model = model.cpu().float() + else: + print("Export model on a single GPU") + model = model.cuda().half() + return model + + +def adapt_inputs_to_device(sample_inputs: tuple, device: torch.device) -> tuple: + """move inputs to device""" + sample_inputs_ = [] + for sample_int in sample_inputs: + if isinstance(sample_int, torch.Tensor): + sample_inputs_.append(sample_int.to(device)) + else: + sample_inputs_.append(sample_int) + return tuple(sample_inputs_) + + +def fetch_onnx_inputs_outputs_name( + model: nn.Module, + onnx_inputs: list, + torch_input_names: tuple, + past_key_values: tuple, + with_past: bool, + input_with_past: bool, +): + """fetch onnx inputs and outputs name""" + num_of_past_key = 0 + kv_cache_axis = {0: "batch_size"} + # try get num_of_past_key and shape of past_key_value + if past_key_values is not None: + num_of_past_key = len(past_key_values) + seq_index = (torch.tensor(past_key_values[0][0].shape) == onnx_inputs[0].shape[-1]).nonzero().view(-1) + assert seq_index.numel() == 1 + kv_cache_axis = {0: "batch_size", seq_index.item(): "seq_len"} + + if not num_of_past_key: + num_of_past_key = model.config.num_hidden_layers + + onnx_inp_names = ("input_ids", "attention_mask") + onnx_out_names = ("logits",) + onnx_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "attention_mask": {0: "batch_size", 1: "seq_len"}, + } + if input_with_past: + for i in range(num_of_past_key): + onnx_inp_names += (f"present_key.{i}",) + onnx_inp_names += (f"present_values.{i}",) + + onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis + onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis + + if with_past or input_with_past: + for i in range(num_of_past_key): + onnx_out_names += (f"past_key.{i}",) + onnx_out_names += (f"past_values.{i}",) + onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis + onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis + + for idx, name in enumerate(torch_input_names): + if input_with_past: + if name == "past_key_values": + onnx_inputs[idx] = past_key_values + elif name == "attention_mask": + attn_mask = onnx_inputs[idx] + onnx_inputs[idx] = torch.cat( + (attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device)), dim=1 + ) + elif name == "input_ids": + input_ids = onnx_inputs[idx] + onnx_inputs[idx] = input_ids[:, -1:] + + return onnx_inp_names, onnx_out_names, onnx_dynamic_axes + + +def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tuple, onnx_path: Path, opset: int): + """do export with torch.onnx.export""" + onnx_model_name = onnx_path.name + onnx_inp_names, onnx_out_names, onnx_dynamic_axes = onnx_io_tuple + # two step to export onnx + # 1. export onnx with lots of pieces of weights + # 2. save all weights to external data + with tempfile.TemporaryDirectory() as tmpdirname: + tmp_onnx = os.path.join(tmpdirname, "tmp.onnx") + + torch.onnx.export( + model=model, + args=tuple(onnx_inputs), + f=tmp_onnx, + verbose=False, + opset_version=opset, + input_names=onnx_inp_names, + output_names=onnx_out_names, + dynamic_axes=onnx_dynamic_axes, + ) + + onnx_path.unlink(missing_ok=True) + (onnx_path.parent / f"{onnx_model_name}_ext.data").unlink(missing_ok=True) + + onnx_model = onnx.load(str(tmp_onnx)) + onnx.save_model( + onnx_model, + str(onnx_path), + save_as_external_data=(len(os.listdir(tmpdirname)) > 1), + all_tensors_to_one_file=True, + location=f"{onnx_model_name}_ext.data", + size_threshold=1024, + convert_attribute=False, + ) + + +@torch.no_grad() +def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int): + """ + do export + model: torch model + onnx_path: where the onnx model saved to + sample_inputs_tp: inputs for torch model + """ + model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir) + + model = move_to_approprate_device(model, sample_inputs_tp) + + sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device) + + # input_keys would be usesful if the model has some special inputs + input_keys, onnx_inputs, past_key_value = retrieve_onnx_inputs(model, sample_inputs, with_past) + + onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, False) + + onnx_model_name = "model.onnx" + onnx_path: Path = Path(onnx_path_str).absolute() + if onnx_path.suffix != ".onnx": + onnx_path = onnx_path / onnx_model_name + + do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset) + if not with_past: + return + + onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, True) + + onnx_model_name = "model_with_past.onnx" + onnx_path = onnx_path.parent / onnx_model_name + + do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset) + + +def parse_arguments(): + """arguments parsing.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--model", + required=True, + type=str, + default=["meta-llama/Llama-2-70b-hf"], + help="Pre-trained models in huggingface model hub", + ) + parser.add_argument( + "-s", + "--saved_path", + required=False, + type=str, + default="./onnx_models/", + help="where the onnx model will be saved", + ) + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default=None, + help=("cache directy of huggingface, by setting this to avoid useless downloading if you have one"), + ) + parser.add_argument( + "--with_past", + action="store_true", + default=False, + help=("The tool will export onnx without past-key-value by default"), + ) + parser.add_argument( + "--opset", + required=False, + type=int, + default=17, + help=( + "the opset to save onnx model, \ + try to increase it if this opset doens't have new features you want" + ), + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + + export_onnx(args.model, args.cache_dir, args.saved_path, args.with_past, args.opset) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md new file mode 100644 index 0000000000..0c6f830ed2 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -0,0 +1,362 @@ +# LLaMA-2 + +## Prerequisites + +Please note the package versions needed for using LLaMA-2 in the `requirements.txt` file that fits your scenario. +- `requirements-cpu.txt` + - For running LLaMA-2 on CPU +- `requirements-cuda.txt` + - For running LLaMA-2 on CUDA + - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. +- `requirements-quant.txt` + - For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor) +- `requirements-70b-model.txt` + - For running the LLaMA-2 70B model on multiple GPUs +- `requirements.txt` + - Package versions needed in each of the above files + +## Exporting LLaMA-2 + +There are several ways to export LLaMA-2 models (using LLaMA-2 7B as an example). + +### Option 1: from convert_to_onnx +``` +# From source: +$ git clone https://github.com/microsoft/onnxruntime +$ cd onnxruntime/onnxruntime/python/tools/transformers/ +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b +``` + +To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf). + +As indicated in `requirements.txt`, you will also need to install Optimum from source. Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows: + +``` +# Before +if self.use_cache: + if past_key_values is not None: + input_ids = input_ids[:, -1:] + # Flatten the past_key_values (no need to flatten for models using multi-query attn) + + +# After +if self.use_cache: + if past_key_values is not None: + input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids + # Flatten the past_key_values (no need to flatten for models using multi-query attn) +``` + +### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx) + +Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2. + +### Option 3: from [Hugging Face Optimum](https://github.com/huggingface/optimum) + +Note that this may produce two ONNX models with older Optimum versions. The above two options produce one ONNX model and installing Optimum from source will now produce one ONNX model. + +First, log into the Hugging Face CLI in your terminal: + +``` +$ huggingface-cli login +``` + +Once authenticated, run the following Python code to export: + +``` +from optimum.onnxruntime import ORTModelForCausalLM + +name = "meta-llama/Llama-2-7b-hf" +model = ORTModelForCausalLM.from_pretrained( + name, + export=True, + use_auth_token=True, +) +model.save_pretrained(name.split("/")[-1] + "-onnx") +``` + +## Examples of Exporting LLaMA-2 + +Here are some additional examples for exporting LLaMA-2. + +Export Model with Different GPU Device Ids +``` +# From source using first GPU: +$ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b + +# From wheel using second GPU: +$ CUDA_VISIBLE_DEVICES=1 python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b +``` + +Export Saved Model on Disk +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b +``` + +Export for FP32 CUDA +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda +``` + +Export for FP32 CPU +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu +``` + +Export for FP16 CUDA (with MultiHeadAttention) +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda +``` + +Export for FP16 CUDA (with GroupQueryAttention) +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --use_gqa + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --use_gqa +``` + +Note: GroupQueryAttention currently runs on Linux for FP16 CUDA and INT4 CUDA models, and it can provide faster inference than MultiHeadAttention, especially for large sequence lengths (e.g. 1024 or larger). For the best performance, you should pre-allocate the KV cache buffers to have size `(batch_size, num_heads, max_sequence_length, head_size)` so that the past KV and present KV caches share the same memory. You also need to bind them with ONNX Runtime's [IO binding](https://onnxruntime.ai/docs/api/python/api_summary.html#iobinding). + +Here is an example of how you can bind directly to `torch.tensor` objects: +``` +# Assumes all inputs and outputs to the model are pre-allocated with the correct shapes in GPU memory + +# Bind inputs +for k, v in inputs.items(): + io_binding.bind_input( + name=k, + device_type="cuda", + device_id=0, + element_type=np.float16, + shape=tuple(v.shape), + buffer_ptr=v.data_ptr() + ) + +# Bind outputs +for output in model.get_outputs(): + name = output.name + if "present" in name: + # Bind KV cache outputs to KV cache inputs + v = inputs[name.replace("present", "past_key_values")] + io_binding.bind_output( + name=name, + device_type="cuda", + device_id=0, + element_type=np.float16, + shape=tuple(v.shape), + buffer_ptr=v.data_ptr() + ) + else: + # Bind other outputs as actual outputs + v = outputs[name] + io_binding.bind_output( + name=name, + device_type="cuda", + device_id=0, + element_type=np.float16, + shape=tuple(v.shape), + buffer_ptr=v.data_ptr() + ) + +io_binding.synchronize_inputs() +sess.run_with_iobinding(io_binding) +io_binding.synchronize_outputs() +``` + +Export for INT8 CPU (SmoothQuant) +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu --no_merged + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu --no_merged +``` + +Note: [Intel's Neural Compressor](https://github.com/intel/neural-compressor) takes time to run the SmoothQuant quantization algorithm on LLMs. On an [Azure Standard_NC24s_v3 VM](https://learn.microsoft.com/en-us/azure/virtual-machines/ncv3-series), it takes about ~30-45 min for each of the exported ONNX models. + +Export for INT8 CPU (DynamicQuant) +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method quantize_dynamic --execution_provider cpu + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method quantize_dynamic --execution_provider cpu +``` + +Export for INT4 CUDA +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda --use_gqa + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda --use_gqa +``` + +Note: See the FP16 CUDA notes about GroupQueryAttention. The `--use_gqa` flag is optional. + +Export for INT4 CPU +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu +``` + +Export LLaMA-2 70B sharded model into 4 partitions +``` +# From source: +# 1. Install necessary packages from requirements-70b-model.txt +$ pip install -r requirements-70b-model.txt + +# 2. Build ONNX Runtime from source with NCCL enabled. Here is a sample command: +$ ./build.sh --config Release --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ + +# 3. Shard and export the LLaMA-2 70B model. With FP16, you will need at least 140GB of GPU memory to load the model. Therefore, you will need at least 4 40GB A100 GPUs or 2 80GB A100 GPUs to shard the PyTorch model and export each shard to ONNX. Here is an example command: +$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-distributed --precision fp16 --execution_provider cuda --use_gqa +``` + +## Benchmark LLaMA-2 + +Here are some examples of how you can benchmark LLaMA-2. + +### Variants + +1. PyTorch without `torch.compile`, FP32 +``` +python3 -m models.llama.benchmark \ + --benchmark-type hf-pt-eager \ + --model-name meta-llama/Llama-2-7b-hf \ + --precision fp32 \ + --batch-sizes "1 2" \ + --sequence-lengths "8 16" \ + --device cpu \ + --auth +``` + +2. PyTorch with `torch.compile`, FP16 +``` +python3 -m models.llama.benchmark \ + --benchmark-type hf-pt-compile \ + --model-name meta-llama/Llama-2-7b-hf \ + --precision fp16 \ + --batch-sizes "1 2" \ + --sequence-lengths "8 16" \ + --device cuda \ + --auth +``` + +3. Optimum + ONNX Runtime, FP32, export via Optimum or convert_to_onnx +``` +python3 -m models.llama.benchmark \ + --benchmark-type hf-ort \ + --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \ + --model-name meta-llama/Llama-2-7b-hf \ + --precision fp32 \ + --batch-sizes "1 2" \ + --sequence-lengths "8 16" \ + --device cpu \ + --auth +``` + +4. Optimum + ONNX Runtime, FP16, export via Optimum or convert_to_onnx +``` +python3 -m models.llama.benchmark \ + --benchmark-type hf-ort \ + --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \ + --model-name meta-llama/Llama-2-7b-hf \ + --precision fp16 \ + --batch-sizes "1 2" \ + --sequence-lengths "8 16" \ + --device cuda \ + --auth +``` + +5. ONNX Runtime, FP32, Microsoft custom export +``` +python3 -m models.llama.benchmark \ + --benchmark-type ort-msft \ + --ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \ + --model-name meta-llama/Llama-2-7b-hf \ + --precision fp32 \ + --batch-sizes "1 2" \ + --sequence-lengths "8 16" \ + --device cpu +``` + +6. ONNX Runtime, FP16, Microsoft custom export +``` +python3 -m models.llama.benchmark \ + --benchmark-type ort-msft \ + --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \ + --model-name meta-llama/Llama-2-7b-hf \ + --precision fp16 \ + --batch-sizes "1 2" \ + --sequence-lengths "8 16" \ + --device cuda +``` + +7. ONNX Runtime, FP32, convert_to_onnx, use 2nd GPU +``` +CUDA_VISIBLE_DEVICES=1 python3 -m models.llama.benchmark \ + --benchmark-type ort-convert-to-onnx \ + --ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \ + --model-name meta-llama/Llama-2-7b-hf \ + --precision fp32 \ + --batch-sizes "1 2" \ + --sequence-lengths "8 16" \ + --device cpu +``` + +8. ONNX Runtime, FP16, convert_to_onnx, use 5th GPU +``` +CUDA_VISIBLE_DEVICES=4 python3 -m models.llama.benchmark \ + --benchmark-type ort-convert-to-onnx \ + --ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx \ + --model-name meta-llama/Llama-2-7b-hf \ + --precision fp16 \ + --batch-sizes "1 2" \ + --sequence-lengths "8 16" \ + --device cuda +``` + +You can profile a variant by adding the `--profile` flag and providing one batch size and sequence length combination. + +### Benchmark All +You can use `benchmark_all.py` to benchmark across various options and automatically store the results in a CSV file. Here is an example. +``` +python3 -m models.llama.benchmark_all \ + --hf-pt-eager \ + --hf-pt-compile \ + --hf-ort-dir-path ./llama2-7b-fp16/ \ + --ort-convert-to-onnx-model-path ./llama2-7b-fp16/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \ + --ort-msft-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \ + --model-name meta-llama/Llama-2-7b-hf \ + --precision fp16 \ + --batch-sizes "1 2" \ + --sequence-lengths "8 16" \ + --device cuda \ + --warmup-runs 5 \ + --num-runs 1000 \ + --timeout 60 # number of minutes before moving to the next benchmark +``` diff --git a/onnxruntime/python/tools/transformers/models/llama/__init__.py b/onnxruntime/python/tools/transformers/models/llama/__init__.py new file mode 100644 index 0000000000..e80f36a391 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/__init__.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os +import sys + +sys.path.append(os.path.dirname(__file__)) + +transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")) +if transformers_dir not in sys.path: + sys.path.append(transformers_dir) diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py new file mode 100644 index 0000000000..021b0dd03a --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -0,0 +1,695 @@ +import argparse +import datetime +import gc +import itertools +import logging +import os +import sys +import time + +import numpy as np +import onnx +import psutil +import torch +from benchmark_helper import measure_memory, setup_logger +from dist_settings import get_rank, get_size +from llama_inputs import ( + add_io_bindings, + get_merged_sample_with_past_kv_inputs, + get_msft_sample_inputs, + get_sample_inputs, + get_sample_with_past_kv_inputs, +) +from optimum.onnxruntime import ORTModelForCausalLM +from torch.profiler import ProfilerActivity, profile, record_function +from tqdm import trange +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +import onnxruntime as ort + +logger = logging.getLogger(__name__) + + +# For determining whether the ONNX model can do both prompt generation and token generation or only one of the two +def get_ort_model_inputs_len(args, model): + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: + return 0 + if args.benchmark_type == "hf-ort": + try: + # New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268) + return len(model.inputs_names) + except Exception: + # Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54) + return len(model.decoder.input_names) + return len(model.get_inputs()) + + +def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): + init_inputs, iter_inputs = None, None + + # For past_present_share_buffer: + # Set max_seq_len to 16384 for CodeLLaMA (finetuned variant of LLaMA-2) + # Set max_seq_len to 4096 for Hugging Face LLaMA-2 model since that is the default value + # Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_seq_len = ( + 2048 + if args.benchmark_type == "ort-msft" + else 16384 + if "codellama" in temp_name + else 4096 + if "llama2" in temp_name + else 2048 + ) + + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: + init_inputs = get_sample_inputs( + args.config, + args.target_device, + args.batch_size, + args.sequence_length, + return_dict=True, + ) + iter_inputs = get_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + args.sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + ) + + elif args.benchmark_type == "hf-ort": + if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids] + # Using split models in Optimum (e.g. created by Optimum export) + init_inputs = get_sample_inputs( + args.config, + args.target_device, + args.batch_size, + args.sequence_length, + return_dict=True, + ) + iter_inputs = get_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + args.sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + ) + else: + # Using merged model in Optimum (e.g. created by convert_to_onnx export) + init_inputs = get_merged_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + seq_len=args.sequence_length, + past_seq_len=0, + max_seq_len=max_seq_len, + use_fp16=args.use_fp16, + use_gqa=args.use_gqa, + engine="pt", + return_dict=True, + ) + iter_inputs = get_merged_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + seq_len=1, + past_seq_len=args.sequence_length, + max_seq_len=max_seq_len, + use_fp16=args.use_fp16, + use_gqa=args.use_gqa, + engine="pt", + return_dict=True, + ) + + elif args.benchmark_type == "ort-convert-to-onnx": + # Microsoft export from convert_to_onnx + init_inputs = get_merged_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + seq_len=args.sequence_length, + past_seq_len=0, + max_seq_len=max_seq_len, + use_fp16=args.use_fp16, + use_gqa=args.use_gqa, + engine="ort", + return_dict=True, + world_size=args.world_size, + ) + iter_inputs = get_merged_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + seq_len=1, + past_seq_len=args.sequence_length, + max_seq_len=max_seq_len, + use_fp16=args.use_fp16, + use_gqa=args.use_gqa, + engine="ort", + return_dict=True, + world_size=args.world_size, + ) + + elif args.benchmark_type == "ort-msft": + # Microsoft export from https://github.com/microsoft/Llama-2-Onnx + split_kv = ort_model_inputs_len > 5 # original inputs: [x, attn_mask, k_cache, v_cache, pos] + + init_inputs = get_msft_sample_inputs( + args.config, + args.batch_size, + past_seq_len=0, + seq_len=args.sequence_length, + max_seq_len=max_seq_len, + use_fp16=args.use_fp16, + use_gqa=args.use_gqa, + split_kv=split_kv, + ) + iter_inputs = get_msft_sample_inputs( + args.config, + args.batch_size, + past_seq_len=args.sequence_length, + seq_len=1, + max_seq_len=max_seq_len, + use_fp16=args.use_fp16, + use_gqa=args.use_gqa, + split_kv=split_kv, + ) + + else: + raise Exception("Unable to auto-detect inputs for provided model") + + return init_inputs, iter_inputs + + +def get_model(args: argparse.Namespace): + model, sess_options = None, None + start_time, end_time = None, None + + # There are multiple sources that the model could come from: + # 1) Benchmark LLaMA-2 from unofficial source on Hugging Face + # 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token + # 3) Benchmark LLaMA-2 from local download of model + # 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx) + # 5) Benchmark LLaMA-2 from convert_to_onnx + + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: + source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name + start_time = time.time() + model = AutoModelForCausalLM.from_pretrained( + source, + torch_dtype=torch.float16 if args.use_fp16 else torch.float32, + use_auth_token=args.auth, + use_cache=True, + ).to(args.target_device) + end_time = time.time() + + if args.benchmark_type == "hf-pt-compile": + model = torch.compile(model) + + elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}: + sess_options = ort.SessionOptions() + sess_options.enable_profiling = args.profile + if args.verbose: + sess_options.log_verbosity_level = 1 + sess_options.log_severity_level = 1 + + else: + raise Exception(f"Cannot recognize {args.benchmark_type}") + + if args.benchmark_type == "hf-ort": + # Optimum export or convert_to_onnx.py export + provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider + provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None + + decoder_file_name = None + decoder_with_past_file_name = None + for filename in os.listdir(args.hf_ort_dir_path): + if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename: + continue + if "decoder_model" in filename or filename == "model.onnx": + decoder_file_name = filename + if "decoder_with_past_model" in filename: + decoder_with_past_file_name = filename + if "decoder_merged_model" in filename: + decoder_file_name = filename + decoder_with_past_file_name = filename + + start_time = time.time() + model = ORTModelForCausalLM.from_pretrained( + args.hf_ort_dir_path, + decoder_file_name=decoder_file_name, + decoder_with_past_file_name=decoder_with_past_file_name, + use_auth_token=args.auth, + use_io_binding=(args.device != "cpu"), + use_merged=(True if decoder_file_name == "model.onnx" else None), + provider=provider, + provider_options=provider_options, + session_options=sess_options, + ) + end_time = time.time() + + if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: + # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx + logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}") + start_time = time.time() + model = ort.InferenceSession( + args.ort_model_path.format(args.rank), + sess_options, + providers=[args.execution_provider], + ) + end_time = time.time() + + logger.info(f"Loaded model in {end_time - start_time} s") + return model + + +def time_fn(args, fn, inputs): + # Warm up + warmup_range = ( + range(args.warmup_runs) + if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} + else trange(args.warmup_runs, file=sys.stdout, desc="Warm up") + ) + + if args.verbose: + outputs = fn(inputs) + logger.info(outputs) + + input_sync = ( # noqa: E731 + lambda *kwargs: args.io_binding.synchronize_inputs() + if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize + else lambda *kwargs: torch.cuda.synchronize() + if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize + else lambda *kwargs: None # no-op function + ) + + output_sync = ( # noqa: E731 + lambda *kwargs: args.io_binding.synchronize_outputs() + if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize + else lambda *kwargs: torch.cuda.synchronize() + if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize + else lambda *kwargs: None # no-op function + ) + + for _ in warmup_range: + input_sync() + fn(inputs) + output_sync() + + # Benchmark + total_time = 0 + bench_range = ( + range(args.num_runs) + if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} + else trange(args.num_runs, file=sys.stdout, desc="Benchmark") + ) + for _ in bench_range: + input_sync() + start_time = time.time() + + fn(inputs) + + output_sync() + end_time = time.time() + + total_time += end_time - start_time + + # Newline print after trange in order to print metrics on new lines without progress bar on same line + if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}: + logger.info("") + + latency = total_time / args.num_runs + throughput = args.batch_size / latency + + if args.rank == 0: + logger.info(f"Batch Size: {args.batch_size}") + logger.info(f"Sequence Length: {args.sequence_length}") + logger.info(f"Latency: {latency} s") + logger.info(f"Throughput: {throughput} tps") + return + + +def profile_fn(args, fn, inputs, inputs_type): + # Filename prefix format: + # "b_s_--___" + prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}" + filename = None + + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: + # Profile PyTorch kernels + with profile( # noqa: SIM117 + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True + ) as prof: + with record_function("model_inference"): + fn(inputs) + prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows) + + filename = os.path.join(args.log_folder, f"{prefix}.log") + with open(filename, "w") as f: + f.write(prof_data) + + else: + # Profile ORT kernels + fn(inputs) + + # Set new log name for ORT profile log generated + filename = f"{prefix}.json" + + return filename + + +def measure_fn(args, fn, inputs): + # Measure CPU usage + pid = os.getpid() + process = psutil.Process(pid) + process.cpu_percent(interval=0.1) + + fn(inputs) + if args.rank == 0: + logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%") + + # Measure memory usage + gc.collect() + torch.cuda.empty_cache() + measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs)) + + # Flush output so memory usage is printed + sys.stdout.flush() + + +def run_hf_inference(args, init_inputs, iter_inputs, model): + # Inference steps to measure + def get_logits(inputs): + # Inference pass without decoding + outputs = model(**inputs) + return outputs + + # Examples of other inference steps that can be measured: + # To use, uncomment the function and assign it to `generate_fn` + + # def get_pred_ids(inputs): + # # Inference pass with predicted token ids generation + # predicted_ids = model.generate(**inputs) + # return predicted_ids + + # def gen_and_dec(inputs): + # # Inference pass with generation and decoding + # predicted_ids = get_pred_ids(inputs) + # transcription = [] + # for bs in range(args.batch_size): + # for rs in range(args.num_return_sequences): + # transcription.append( + # args.tokenizer.batch_decode( + # predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True + # )[0] + # ) + # return transcription + + generate_fn = get_logits + + if args.benchmark_type == "hf-pt-compile": + # Run forward pass once with each set of inputs to process through Dynamo + generate_fn(init_inputs) + generate_fn(iter_inputs) + + if args.profile: + new_logname = profile_fn(args, generate_fn, init_inputs, "prompt") + if args.benchmark_type == "hf-ort": + # Turn profiling off to stop appending to log + old_logname = model.decoder.session.end_profiling() + logger.warning(f"Renaming {old_logname} to {new_logname}") + os.rename(old_logname, os.path.join(args.log_folder, new_logname)) + + new_logname = profile_fn(args, generate_fn, iter_inputs, "token") + if args.benchmark_type == "hf-ort": + # Turn profiling off to stop appending to log + old_logname = model.decoder_with_past.session.end_profiling() + logger.warning(f"Renaming {old_logname} to {new_logname}") + os.rename(old_logname, os.path.join(args.log_folder, new_logname)) + + return + + # PyTorch evaluations + logger.info("\nEvaluating `model(inputs)` step to get past_key_values") + time_fn(args, generate_fn, init_inputs) + measure_fn(args, generate_fn, init_inputs) + + logger.info("\nEvaluating `model(inputs)` step with past_key_values") + time_fn(args, generate_fn, iter_inputs) + measure_fn(args, generate_fn, iter_inputs) + + +def run_ort_inference(args, init_inputs, iter_inputs, model): + def prepare_ort_inputs(inputs, kv_cache_ortvalues): + # Check that all model inputs will be provided + model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs())) + user_inputs = set(inputs.keys()) + missing_inputs = model_inputs - user_inputs + if len(missing_inputs): + logger.error(f"The following model inputs are missing: {missing_inputs}") + raise Exception("There are missing inputs to the model. Please add them and try again.") + + # Remove unnecessary inputs from model inputs + unnecessary_inputs = user_inputs - model_inputs + if len(unnecessary_inputs): + for unnecessary_input in unnecessary_inputs: + logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs") + del inputs[unnecessary_input] + + # Add IO bindings for non-CPU execution providers + if args.device != "cpu": + io_binding, kv_cache_ortvalues = add_io_bindings( + model, inputs, args.device, int(args.rank), args.use_gqa, kv_cache_ortvalues + ) + setattr(args, "io_binding", io_binding) # noqa: B010 + return io_binding, kv_cache_ortvalues + + return inputs, kv_cache_ortvalues + + def with_io_binding(io_binding): + # Inference pass with IO binding + model.run_with_iobinding(io_binding) + + def without_io_binding(inputs): + # Inference pass without IO binding + outputs = model.run(None, inputs) + return outputs + + generate_fn = with_io_binding if args.device != "cpu" else without_io_binding + kv_cache_ortvalues = {} + + if args.profile: + ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues) + new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt") + + # Turn profiling off to stop appending to log file + old_logname = model.end_profiling() + logger.warning(f"Renaming {old_logname} to {new_logname}") + os.rename(old_logname, os.path.join(args.log_folder, new_logname)) + + # Re-initialize model for new log file instead of appending to old log file + model = get_model(args) + ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues) + new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token") + + # Turn profiling off to stop appending to log + old_logname = model.end_profiling() + logger.warning(f"Renaming {old_logname} to {new_logname}") + os.rename(old_logname, os.path.join(args.log_folder, new_logname)) + return + + # ORT evaluations + logger.info("\nEvaluating `model(inputs)` step to get past_key_values") + ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues) + time_fn(args, generate_fn, ort_init_inputs) + measure_fn(args, generate_fn, ort_init_inputs) + + logger.info("\nEvaluating `model(inputs)` step with past_key_values") + ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues) + time_fn(args, generate_fn, ort_iter_inputs) + measure_fn(args, generate_fn, ort_iter_inputs) + + +def run_inference(args, init_inputs, iter_inputs, model): + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}: + run_hf_inference(args, init_inputs, iter_inputs, model) + elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: + run_ort_inference(args, init_inputs, iter_inputs, model) + else: + raise Exception(f"Cannot recognize {args.benchmark_type}") + + +def get_args(rank=0): + parser = argparse.ArgumentParser() + parser.add_argument( + "-bt", + "--benchmark-type", + type=str, + required=True, + choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort-msft", "ort-convert-to-onnx"], + ) + parser.add_argument( + "-m", + "--model-name", + type=str, + required=True, + help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')", + ) + parser.add_argument( + "-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model" + ) + + # Args for choosing the model + parser.add_argument( + "-p", + "--precision", + required=True, + type=str, + default="fp32", + choices=["int4", "int8", "fp16", "fp32"], + help="Precision for model. For ONNX models, the model's precision should be set before running this script.", + ) + parser.add_argument( + "--hf-pt-dir-path", + type=str, + default="", + help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)", + ) + parser.add_argument( + "--hf-ort-dir-path", + type=str, + default="", + help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)", + ) + parser.add_argument( + "--ort-model-path", + type=str, + default="", + help="Path to ONNX model", + ) + + # Args for running and evaluating the model + parser.add_argument( + "-b", + "--batch-sizes", + default="1 2", + ) + parser.add_argument( + "-s", + "--sequence-lengths", + default="32 64 128 256 512", + ) + parser.add_argument( + "-d", + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + choices=["cpu", "cuda", "rocm"], + ) + parser.add_argument("-id", "--device-id", type=int, default=0) + parser.add_argument("-w", "--warmup-runs", type=int, default=5) + parser.add_argument("-n", "--num-runs", type=int, default=10) + parser.add_argument("--seed", type=int, default=2) + + # Args for decoding logic + parser.add_argument("--max-length", type=int, default=32) + parser.add_argument("--num-return-sequences", type=int, default=1) + + # Args for accessing detailed info + parser.add_argument("--profile", default=False, action="store_true") + parser.add_argument( + "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by" + ) + parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display") + parser.add_argument("--verbose", default=False, action="store_true") + parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files") + + args = parser.parse_args() + + # Set seed properties + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + # Set runtime properties + if "ort" in args.benchmark_type: + setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010 + if args.execution_provider == "CUDAExecutionProvider": + args.execution_provider = (args.execution_provider, {"device_id": rank}) + elif args.execution_provider == "ROCMExecutionProvider": + args.execution_provider = (args.execution_provider, {"device_id": rank}) + args.device = "cuda" + + # Check that paths have been specified for any benchmarking with ORT + if args.benchmark_type == "hf-ort": + assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`" + if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: + assert args.ort_model_path, "Please specify a path to `--ort-model-path`" + + args.batch_sizes = args.batch_sizes.split(" ") + args.sequence_lengths = args.sequence_lengths.split(" ") + + # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models + args.precision = ( + "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16" + ) + + # Check that only one (batch_size, sequence_length) combination is set for profiling + if args.profile: + assert ( + len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1 + ), "Please provide only one (batch_size, sequence_length) combination for profiling" + + return args + + +def main(): + rank = get_rank() + world_size = get_size() + + args = get_args(rank) + setup_logger(args.verbose) + logger.info(args.__dict__) + torch.backends.cudnn.benchmark = True + + args.rank = rank + args.world_size = world_size + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + config = AutoConfig.from_pretrained(args.model_name) + target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device + use_fp16 = args.precision == "fp16" + + setattr(args, "tokenizer", tokenizer) # noqa: B010 + setattr(args, "config", config) # noqa: B010 + setattr(args, "target_device", target_device) # noqa: B010 + setattr(args, "use_fp16", use_fp16) # noqa: B010 + + # Get model and model info + model = get_model(args) + ort_model_inputs_len = get_ort_model_inputs_len(args, model) + + # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA) + if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}: + onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False) + gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node)) + + use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu" + setattr(args, "use_gqa", use_buffer_share) # noqa: B010 + else: + setattr(args, "use_gqa", False) # noqa: B010 + + # Measure prompt cost (init_inputs) and generated token cost (iter_inputs) + for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths): + if args.rank == 0: + logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...") + setattr(args, "batch_size", int(batch_size)) # noqa: B010 + setattr(args, "sequence_length", int(sequence_length)) # noqa: B010 + + init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len) + run_inference(args, init_inputs, iter_inputs, model) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh b/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh new file mode 100644 index 0000000000..38f1916456 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +NUM_GPUS=${1:-1} + +MPI="mpirun --allow-run-as-root + -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 + --tag-output --npernode $NUM_GPUS --bind-to numa + -x MIOPEN_FIND_MODE=1" + +CMD="$MPI python benchmark.py ${@:2}" + +$CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py new file mode 100644 index 0000000000..b35a5e27f9 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py @@ -0,0 +1,411 @@ +import argparse +import datetime +import json +import logging +import os +import subprocess + +import torch +from benchmark_helper import setup_logger + +logger = logging.getLogger(__name__) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-b", + "--batch-sizes", + type=str, + default="1 2", + ) + + parser.add_argument( + "-s", + "--sequence-lengths", + type=str, + default="8 16 32 64 128 256 512", + ) + + parser.add_argument( + "-w", + "--warmup-runs", + type=int, + default=5, + ) + + parser.add_argument( + "-n", + "--num-runs", + type=int, + default=1000, + ) + + parser.add_argument( + "--hf-pt-eager", + default=False, + action="store_true", + help="Benchmark in PyTorch without `torch.compile`", + ) + + parser.add_argument( + "--hf-pt-compile", + default=False, + action="store_true", + help="Benchmark in PyTorch with `torch.compile`", + ) + + parser.add_argument( + "--hf-ort-dir-path", + type=str, + default="", + help="Path to folder containing ONNX models for Optimum + ORT benchmarking", + ) + + parser.add_argument( + "--ort-msft-model-path", + type=str, + default="", + help="Path to ONNX model from https://github.com/microsoft/Llama-2-Onnx", + ) + + parser.add_argument( + "--ort-convert-to-onnx-model-path", + type=str, + default="", + help="Path to ONNX model from convert_to_onnx", + ) + + parser.add_argument( + "--model-name", + type=str, + required=True, + help="Model name in Hugging Face", + ) + + parser.add_argument( + "--precision", + type=str, + required=True, + choices=["int4", "int8", "fp16", "fp32"], + help="Precision to run model", + ) + + parser.add_argument( + "--device", + type=str, + required=True, + choices=["cpu", "cuda", "rocm"], + help="Device to benchmark models", + ) + + parser.add_argument( + "--device-id", + type=int, + default=0, + help="GPU device ID", + ) + + parser.add_argument( + "--verbose", + default=False, + action="store_true", + help="Print detailed logs", + ) + + parser.add_argument( + "--timeout", + type=int, + default=10, + help="Number of mins to attempt the benchmark before moving on", + ) + + args = parser.parse_args() + + setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010 + log_folder_name = f"./{args.model_size}_{args.precision}" + setattr(args, "log_folder", log_folder_name) # noqa: B010 + os.makedirs(args.log_folder, exist_ok=True) + + # Convert timeout value to secs + args.timeout *= 60 + + return args + + +def process_log_file(device_id, log_file, base_results): + entries = [] + batch_size, sequence_length, step = None, None, None + latency_s, latency_ms, throughput, memory = None, None, None, None + + batch_pattern = "Batch Size: " + sequence_pattern = "Sequence Length: " + prompt_step_pattern = "to get past_key_values" + per_token_step_pattern = "with past_key_values" + latency_pattern = "Latency: " + throughput_pattern = "Throughput: " + memory_pattern = "peak=" + + with open(log_file) as f: + for input_line in f: + line = input_line.replace("\n", "") + + if batch_pattern in line: + batch_size = int(line[len(batch_pattern) :]) + elif sequence_pattern in line: + sequence_length = int(line[len(sequence_pattern) :]) + elif prompt_step_pattern in line: + step = "prompt" + elif per_token_step_pattern in line: + step = "per-token" + elif latency_pattern in line: + latency_s = float(line[len(latency_pattern) : line.rfind(" ")]) + latency_ms = latency_s * 1000 + elif throughput_pattern in line: + throughput = float(line[len(throughput_pattern) : line.rfind(" ")]) + elif memory_pattern in line: + if "CPU" in line: + # Example format for log entry: + # CPU memory usage: before=1000.0 MB, peak=2000.0 MB + memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000 + else: + # Example format for log entry: + # GPU memory usage: before=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 69637.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}] peak=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 73861.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}] + peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"') + usage = json.loads(peak)[device_id]["max_used_MB"] + memory = float(usage) / 1000 + + # Append log entry to list of entries + entry = base_results + [ # noqa: RUF005 + batch_size, + sequence_length, + step, + latency_s, + latency_ms, + throughput, + memory, + ] + entries.append(entry) + + return entries + + +def save_results(results, filename): + import pandas as pd + + df = pd.DataFrame( + results, + columns=[ + "Engine", + "Precision", + "Device", + "Batch Size", + "Sequence Length", + "Step", + "Latency (s)", + "Latency (ms)", + "Throughput (tps)", + "Memory (GB)", + ], + ) + + # Set column types + df["Batch Size"] = df["Batch Size"].astype("int") + df["Sequence Length"] = df["Sequence Length"].astype("int") + df["Latency (s)"] = df["Latency (s)"].astype("float") + df["Latency (ms)"] = df["Latency (ms)"].astype("float") + df["Throughput (tps)"] = df["Throughput (tps)"].astype("float") + df["Memory (GB)"] = df["Memory (GB)"].astype("float") + + df.to_csv(filename, index=False) + logger.info(f"Results saved in {filename}!") + + +def benchmark(args, benchmark_cmd, engine): + log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log" + log_path = os.path.join(args.log_folder, log_filename) + with open(log_path, "w") as log_file: + process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file) + try: + process.wait(args.timeout) + except subprocess.TimeoutExpired: + process.kill() + + # Create entries for csv + logger.info("Gathering data from log files...") + base_results = [engine, args.precision, args.device] + results = process_log_file(args.device_id, log_path, base_results) + + return results + + +def main(): + args = get_args() + setup_logger(args.verbose) + logger.info(args.__dict__) + torch.backends.cudnn.benchmark = True + + all_results = [] + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id) + + # Benchmark PyTorch without torch.compile + if args.hf_pt_eager: + benchmark_cmd = [ + "python", + "-m", + "models.llama.benchmark", + "--benchmark-type", + "hf-pt-eager", + "--model-name", + args.model_name, + "--precision", + args.precision, + "--batch-sizes", + args.batch_sizes, + "--sequence-lengths", + args.sequence_lengths, + "--device", + args.device, + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + "--auth", + ] + logger.info("Benchmark PyTorch without torch.compile") + results = benchmark(args, benchmark_cmd, "pytorch-eager") + all_results.extend(results) + + # Benchmark PyTorch with torch.compile + if args.hf_pt_compile: + benchmark_cmd = [ + "python", + "-m", + "models.llama.benchmark", + "--benchmark-type", + "hf-pt-compile", + "--model-name", + args.model_name, + "--precision", + args.precision, + "--batch-sizes", + args.batch_sizes, + "--sequence-lengths", + args.sequence_lengths, + "--device", + args.device, + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + "--auth", + ] + logger.info("Benchmark PyTorch with torch.compile") + results = benchmark(args, benchmark_cmd, "pytorch-compile") + all_results.extend(results) + + # Benchmark Optimum + ONNX Runtime + if args.hf_ort_dir_path: + benchmark_cmd = [ + "python", + "-m", + "models.llama.benchmark", + "--benchmark-type", + "hf-ort", + "--hf-ort-dir-path", + args.hf_ort_dir_path, + "--model-name", + args.model_name, + "--precision", + args.precision, + "--batch-sizes", + args.batch_sizes, + "--sequence-lengths", + args.sequence_lengths, + "--device", + args.device, + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + "--auth", + ] + logger.info("Benchmark Optimum + ONNX Runtime") + results = benchmark(args, benchmark_cmd, "optimum-ort") + all_results.extend(results) + + # Benchmark Microsoft model in ONNX Runtime + if args.ort_msft_model_path: + benchmark_cmd = [ + "python", + "-m", + "models.llama.benchmark", + "--benchmark-type", + "ort-msft", + "--ort-model-path", + args.ort_msft_model_path, + "--model-name", + args.model_name, + "--precision", + args.precision, + "--batch-sizes", + args.batch_sizes, + "--sequence-lengths", + args.sequence_lengths, + "--device", + args.device, + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + logger.info("Benchmark Microsoft model in ONNX Runtime") + results = benchmark(args, benchmark_cmd, "ort-msft") + all_results.extend(results) + + # Benchmark convert_to_onnx model in ONNX Runtime + if args.ort_convert_to_onnx_model_path: + benchmark_cmd = [ + "python", + "-m", + "models.llama.benchmark", + "--benchmark-type", + "ort-convert-to-onnx", + "--ort-model-path", + args.ort_convert_to_onnx_model_path, + "--model-name", + args.model_name, + "--precision", + args.precision, + "--batch-sizes", + args.batch_sizes, + "--sequence-lengths", + args.sequence_lengths, + "--device", + args.device, + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + logger.info("Benchmark convert_to_onnx model in ONNX Runtime") + results = benchmark(args, benchmark_cmd, "onnxruntime") + all_results.extend(results) + + csv_file = f"{args.model_size}_{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv" + save_results(all_results, os.path.join(args.log_folder, csv_file)) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh b/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh new file mode 100644 index 0000000000..637d15c10e --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +NUM_GPUS=${1:-1} + +MPI="mpirun --allow-run-as-root + -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 + --tag-output --npernode $NUM_GPUS --bind-to numa + -x MIOPEN_FIND_MODE=1" + +CMD="$MPI python convert_to_onnx.py ${@:2}" + +$CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py new file mode 100644 index 0000000000..c9c7f4d39d --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -0,0 +1,965 @@ +import argparse +import logging +import os +import shutil +from itertools import chain +from typing import List + +import onnx +import torch +from benchmark_helper import Precision, prepare_environment, setup_logger +from convert_generation import replace_mha_with_gqa +from dist_settings import barrier, get_rank, get_size, init_dist +from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs +from llama_parity import main as parity_check +from llama_torch import setup_torch_model +from onnx_model import OnnxModel +from optimizer import optimize_model +from packaging import version +from transformers import AutoConfig, AutoModelForCausalLM + +from onnxruntime import quantization as ort_quantization +from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer + +logger = logging.getLogger("") +init_dist() + + +def get_model_dynamic_axes(input_names: List[str], output_names: List[str]): + dynamic_axes = {} + for name in input_names + output_names: + if name in input_names: + # shape is (batch_size, sequence_length) + dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"} + elif name == "logits": + # shape is (batch_size, sequence_length, vocab_size) + dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"} + elif "present" in name: + # shape is (batch_size, num_heads, sequence_length, head_size) + dynamic_axes[name] = {0: "batch_size", 2: "sequence_length"} + else: + raise Exception("Unknown input or output name found") + return dynamic_axes + + +def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: List[str]): + dynamic_axes = {} + for name in input_names + output_names: + if name in {"input_ids", "position_ids"}: + # shape is (batch_size, 1) + dynamic_axes[name] = {0: "batch_size"} + elif name == "attention_mask": + # shape is (batch_size, past_sequence_length + 1) + dynamic_axes[name] = {0: "batch_size", 1: "past_sequence_length + 1"} + elif "past" in name: + # shape is (batch_size, num_heads, past_sequence_length, head_size) + dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"} + elif name == "logits": + # shape is (batch_size, 1, vocab_size) + dynamic_axes[name] = {0: "batch_size"} + elif "present" in name: + # shape is (batch_size, num_heads, past_sequence_length + 1, head_size) + dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length + 1"} + else: + raise Exception("Unknown input or output name found") + return dynamic_axes + + +def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str]): + dynamic_axes = {} + for name in input_names + output_names: + if name in {"input_ids", "position_ids"}: + # shape is (batch_size, sequence_length) + dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"} + elif name == "attention_mask": + # shape is (batch_size, past_sequence_length + sequence_length) = (batch_size, total_sequence_length) + # for prompt generation, past_sequence_length = 0 + # for token generation, sequence_length = 1 + dynamic_axes[name] = {0: "batch_size", 1: "total_sequence_length"} + elif "past" in name: + # shape is (batch_size, num_heads, past_sequence_length, head_size) + dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"} + elif name == "logits": + # shape is (batch_size, sequence_length, vocab_size) + dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"} + elif "present" in name: + # shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size) = (batch_size, num_heads, total_sequence_length, head_size) + # for prompt generation, past_sequence_length = 0 + # for token generation, sequence_length = 1 + dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"} + else: + raise Exception("Unknown input or output name found") + return dynamic_axes + + +def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: str): + onnx.save( + onnx_model, + output_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=data_path, + size_threshold=1024, + convert_attribute=False, + ) + + +# Notes: +# 1) Dynamo export will not work automatically until this issue is resolved: https://github.com/microsoft/onnxscript/issues/493 +# +# 2) Dynamo export will run manually if you set the ONNX file path to the same path that you use to save the model after export. +# In other words, the value of `temp_path` should be set as the ONNX file path. You can open the issue in your browser to find +# the location in ONNX Script where you have to make this change. +# +# Once the issue is resolved, we hope to modify the code below as follows for each export. +# +# Before: +# temp_dir = args.output +# temp_path = os.path.join(temp_dir, "temp.onnx") +# ... +# ... +# ... +# del onnx_model +# os.system(f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}") +# +# +# After: +# temp_dir = tempfile.TemporaryDirectory() +# temp_path = os.path.join(temp_dir.name, "temp.onnx") +# ... +# ... +# ... +# del onnx_model +# temp_dir.cleanup() +# +def run_dynamo_export( + args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1 +): + from torch._dynamo import config + + config.capture_scalar_outputs = True + + # Dummy values for export + batch_size, sequence_length = 2, 8 + device = torch.device("cpu") + + # Export decoder_model.onnx + input_ids, attn_mask, pos_ids = get_sample_inputs(l_config, device, batch_size, sequence_length) + temp_dir = args.output # tempfile.TemporaryDirectory() + temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx") + torch.onnx.dynamo_export( + llama, input_ids, attn_mask, pos_ids, export_options=torch.onnx.ExportOptions(dynamic_shapes=True) + ).save(temp_path) + + # Check decoder_model.onnx and save all external data to one file + onnx.checker.check_model(temp_path) + onnx.shape_inference.infer_shapes_path(temp_path) + + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") + onnx_model = onnx.load_model(temp_path, load_external_data=True) + save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data") + del onnx_model + os.system( + f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" + ) # temp_dir.cleanup() + + # Export decoder_with_past_model.onnx + input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs( + l_config, device, batch_size, sequence_length, world_size=world_size + ) + temp_dir = args.output # tempfile.TemporaryDirectory() + temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx") + torch.onnx.dynamo_export( + llama, input_ids, attn_mask, pos_ids, past_kv, export_options=torch.onnx.ExportOptions(dynamic_shapes=True) + ).save(temp_path) + + # Check decoder_with_past_model.onnx and save all external data to one file + onnx.checker.check_model(temp_path) + onnx.shape_inference.infer_shapes_path(temp_path) + + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx") + onnx_model = onnx.load_model(temp_path, load_external_data=True) + save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data") + del onnx_model + os.system( + f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" + ) # temp_dir.cleanup() + + logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!") + + +def _prepare_dir(dir_path): + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + +def run_torchscript_separate_export( + args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1 +): + # Dummy values for export + batch_size, sequence_length = 2, 8 + + # set device used to export model + # for llama-2-70b we will use current gpus to speed up export process + # for other models, we will use CPU to make sure we have enough memory to do export + device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") + + # Export decoder_model.onnx + decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length) + + input_names = ["input_ids", "attention_mask", "position_ids"] + output_names = [ + "logits", + *list( + chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers)) + ), + ] + dynamic_axes = get_model_dynamic_axes(input_names, output_names) + + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. + temp_dir = f"./temp_{rank}" + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") + torch.onnx.export( + llama, + args=decoder_inputs, + f=temp_path, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=13, + do_constant_folding=True, + verbose=args.verbose, + ) + + # Check decoder_model.onnx and save all external data to one file + onnx.checker.check_model(temp_path) + onnx.shape_inference.infer_shapes_path(temp_path) + + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") + onnx_model = onnx.load_model(temp_path, load_external_data=True) + save_onnx_model( + onnx_model, + output_path, + f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data", + ) + del onnx_model + shutil.rmtree(temp_dir) + + # Export decoder_with_past_model.onnx + decoder_with_past_inputs = get_sample_with_past_kv_inputs( + l_config, + device, + batch_size, + sequence_length, + use_fp16=args.precision == Precision.FLOAT16, + world_size=world_size, + ) + input_names = [ + "input_ids", + "attention_mask", + "position_ids", + *list( + chain.from_iterable( + (f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(l_config.num_hidden_layers) + ) + ), + ] + output_names = [ + "logits", + *list( + chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers)) + ), + ] + dynamic_axes = get_model_with_past_kv_dynamic_axes(input_names, output_names) + + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. + temp_dir = f"./temp_past_{rank}" + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") + torch.onnx.export( + llama, + args=decoder_with_past_inputs, + f=temp_path, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=13, + do_constant_folding=True, + verbose=args.verbose, + ) + + # Check decoder_with_past_model.onnx and save all external data to one file + onnx.checker.check_model(temp_path) + onnx.shape_inference.infer_shapes_path(temp_path) + + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx") + onnx_model = onnx.load_model(temp_path, load_external_data=True) + save_onnx_model( + onnx_model, + output_path, + f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data", + ) + del onnx_model + shutil.rmtree(temp_dir) + + logger.info( + f"The {args.model_name} separate ONNX model has been successfully created with the TorchScript exporter!" + ) + + +def run_torchscript_merged_export( + args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1 +): + # Dummy values for export + batch_size, sequence_length, past_sequence_length = 2, 8, 0 + + # set device used to export model + # for llama-2-70b we will use current gpus to speed up export process + # for other models, we will use CPU to make sure we have enough memory to do export + device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") + + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 + + # Export decoder_merged_model.onnx + decoder_merged_inputs = get_merged_sample_with_past_kv_inputs( + l_config, + device, + batch_size, + sequence_length, + past_sequence_length, + max_seq_len=max_sequence_length, + use_fp16=args.precision == Precision.FLOAT16, + world_size=world_size, + ) + input_names = [ + "input_ids", + "attention_mask", + "position_ids", + *list( + chain.from_iterable( + (f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(l_config.num_hidden_layers) + ) + ), + ] + output_names = [ + "logits", + *list( + chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers)) + ), + ] + dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names) + + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. + temp_dir = f"./temp_{rank}" + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") + torch.onnx.export( + llama, + args=decoder_merged_inputs, + f=temp_path, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=13, + do_constant_folding=True, + verbose=args.verbose, + ) + + # Check decoder_merged_model.onnx and save all external data to one file + onnx.checker.check_model(temp_path) + onnx.shape_inference.infer_shapes_path(temp_path) + + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx") + onnx_model = onnx.load_model(temp_path, load_external_data=True) + save_onnx_model( + onnx_model, + output_path, + f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx.data", + ) + del onnx_model + shutil.rmtree(temp_dir) + + logger.info(f"The {args.model_name} merged ONNX model has been successfully created with the TorchScript exporter!") + + +# Optimize the model as FP32 +def optimize_export(config: AutoConfig, input_path: str, output_path: str): + from fusion_options import FusionOptions + + optimization_options = FusionOptions("gpt2") + + model_opt = optimize_model( + input_path, + model_type="gpt2", + num_heads=config.num_attention_heads, + hidden_size=config.hidden_size, + opt_level=0, + optimization_options=optimization_options, + only_onnxruntime=False, + ) + model_opt.save_model_to_file(output_path, use_external_data_format=True) + logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!") + remove_existing_model(input_path) + + +def convert_to_float16( + args: argparse.Namespace, config: AutoConfig, old_paths: List[str], rank: int = 0, world_size: int = 1 +): + decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx") + decoder_with_past_model_fp16_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp16.onnx" + ) + decoder_merged_model_fp16_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp16.onnx" + ) + new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path] + + logger.info("Converting to float16...") + for fp32_path, fp16_path in zip(old_paths, new_paths): + if os.path.exists(fp32_path): + model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True)) + model.convert_float_to_float16(keep_io_types=False) + if args.use_gqa: + model = use_group_query_attention(config, model, world_size) + model.save_model_to_file(fp16_path, use_external_data_format=True) + del model + logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!") + remove_existing_model(fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully converted to float16!") + return new_paths + + +def use_group_query_attention(config: AutoConfig, fp16_model_opt: OnnxModel, world_size: int = 1): + # Replace MultiHeadAttention with GroupQueryAttention + fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "attention_mask", config.num_key_value_heads, world_size) + fp16_model_opt.prune_graph() + fp16_model_opt.update_graph(allow_remove_graph_inputs=True) + return fp16_model_opt + + +def smooth_quant( + args: argparse.Namespace, + decoder_model_fp32_path: str, + decoder_with_past_model_fp32_path: str, + decoder_model_int8_path: str, + decoder_with_past_model_int8_path: str, +): + from neural_compressor import PostTrainingQuantConfig + from neural_compressor import quantization as intel_quantization + from neural_compressor import set_workspace + from onnx.external_data_helper import load_external_data_for_model + from quant_kv_dataloader import QuantKVDataLoader + + set_workspace(args.nc_workspace) + quantization_config = PostTrainingQuantConfig( + calibration_sampling_size=[args.calibration_sampling_size], + recipes={ + "optypes_to_exclude_output_quant": ["MatMul"], + "smooth_quant": True, + "smooth_quant_args": {"alpha": args.smooth_quant_alpha}, + }, + op_type_dict={ + "^((?!(MatMul|Gather|Conv)).)*$": { + "weight": {"dtype": ["fp32"]}, + "activation": {"dtype": ["fp32"]}, + } + }, + ) + + # Convert decoder_model.onnx to INT8 + decoder_model_int8 = intel_quantization.fit( + decoder_model_fp32_path, + quantization_config, + calib_dataloader=QuantKVDataLoader(args), + ) + load_external_data_for_model( + decoder_model_int8._model, + os.path.split(decoder_model_int8._model_path)[0], + ) + save_onnx_model( + decoder_model_int8._model, + decoder_model_int8_path, + f"{args.model_name}_decoder_model_int8.onnx.data", + ) + del decoder_model_int8 + logger.info( + f"The ONNX model at {decoder_model_fp32_path} has been quantized to int8 and saved at {decoder_model_int8_path}!" + ) + remove_existing_model(decoder_model_fp32_path) + + # Convert decoder_with_past_model.onnx to INT8 + decoder_with_past_model_int8 = intel_quantization.fit( + decoder_with_past_model_fp32_path, + quantization_config, + calib_dataloader=QuantKVDataLoader(args, onnx_model_path=decoder_model_fp32_path), + ) + load_external_data_for_model( + decoder_with_past_model_int8._model, + os.path.split(decoder_with_past_model_int8._model_path)[0], + ) + save_onnx_model( + decoder_with_past_model_int8._model, + decoder_with_past_model_int8_path, + f"{args.model_name}_decoder_with_past_model_int8.onnx.data", + ) + del decoder_with_past_model_int8 + logger.info( + f"The ONNX model at {decoder_with_past_model_fp32_path} has been quantized to int8 and saved at {decoder_with_past_model_int8_path}!" + ) + remove_existing_model(decoder_with_past_model_fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + + logger.warning(f"Removing {args.nc_workspace}") + shutil.rmtree(args.nc_workspace) + + +def remove_existing_model(model_path: str): + # Remove ONNX model and its external data + data_path = os.path.join(model_path + ".data") + os.remove(model_path) + os.remove(data_path) + logger.warning(f"Removed {model_path} and {data_path}") + + +def remove_existing_files(output_path: str): + for filename in os.listdir(output_path): + filepath = os.path.join(output_path, filename) + if ".onnx" in filename or ".onnx.data" in filename: + os.remove(filepath) + logger.warning(f"Removed {filepath}") + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--model_name", + required=True, + help="Model name in Hugging Face", + ) + + parser.add_argument( + "-i", + "--input", + required=False, + default=os.path.join("."), + help="Directory path to PyTorch model and associated files if saved on disk", + ) + + parser.add_argument( + "-o", + "--output", + required=False, + default=os.path.join(".", "llama_onnx_models"), + help="Directory path to save exported model files in", + ) + + parser.add_argument( + "-p", + "--precision", + required=False, + type=Precision, + default=Precision.FLOAT32, + choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4], + help="Precision to export model in", + ) + + parser.add_argument( + "-e", + "--execution_provider", + required=False, + default="cpu", + choices=["cpu", "cuda", "rocm"], + help="Execution provider to verify parity with", + ) + + parser.add_argument( + "-r", + "--reexport", + required=False, + action="store_true", + help="Re-export models and overwrite existing models in output folder", + ) + parser.set_defaults(reexport=False) + + parser.add_argument( + "--use_gqa", + required=False, + action="store_true", + help="Use GroupQueryAttention instead of MultiHeadAttention", + ) + parser.set_defaults(use_gqa=False) + + parser.add_argument( + "--no_merged", + required=False, + action="store_true", + help="Export models into 2 ONNX files instead of 1. Deprecated in favor of exporting into 1 ONNX file.", + ) + parser.set_defaults(no_merged=False) + + parser.add_argument( + "-q", + "--quantization_method", + default="", + choices=["blockwise", "smooth_quant", "quantize_dynamic"], + help="Run a specific quantization algorithm (blockwise for int4, smooth_quant for int8, quantize_dynamic for int8). Blockwise is recommended. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.", + ) + + blockwise_group = parser.add_argument_group("4-bit quantization") + + blockwise_group.add_argument( + "--block_size", + required=False, + default=32, + type=int, + help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.", + ) + + smooth_quant_group = parser.add_argument_group("smooth_quant (8-bit quantization)") + + smooth_quant_group.add_argument( + "--smooth_quant_alpha", + required=False, + default=0.8, + type=float, + help="Strength to control migration difficulty from activation to weights. Default is 0.8 to match value \ + used in original paper for LLaMA. Paper recommends using values in [0.4, 0.6] range. \ + Link to paper: https://arxiv.org/pdf/2211.10438.pdf", + ) + + smooth_quant_group.add_argument( + "--smooth_quant_dataset", + required=False, + default="NeelNanda/pile-10k", + help="Path to dataset for calibration during quantization", + ) + + smooth_quant_group.add_argument( + "--pad_max", + required=False, + default=196, + type=int, + help="Max padding size", + ) + + smooth_quant_group.add_argument( + "--calibration_sampling_size", + required=False, + type=int, + default=8, + help="Calibration sampling size for quantization config", + ) + + smooth_quant_group.add_argument( + "--nc_workspace", + required=False, + type=str, + default=os.path.join(".", "nc_workspace"), + help="Workspace to save intermediate files generated by Intel's Neural Compressor package.", + ) + + quantize_dynamic_group = parser.add_argument_group("quantize_dynamic (8-bit quantization)") + + quantize_dynamic_group.add_argument( + "--quantize_embedding_layer", + required=False, + action="store_true", + help="Quantize MatMul, GEMM, and Gather.", + ) + quantize_dynamic_group.set_defaults(quantize_embedding_layer=False) + + quantize_dynamic_group.add_argument( + "--quantize_per_channel", + required=False, + action="store_true", + help="Quantize weights per each channel.", + ) + quantize_dynamic_group.set_defaults(quantize_per_channel=False) + + quantize_dynamic_group.add_argument( + "--quantize_reduce_range", + required=False, + action="store_true", + help="Quantize weights with 7 bits.", + ) + quantize_dynamic_group.set_defaults(quantize_reduce_range=False) + + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Print verbose logs", + ) + parser.set_defaults(verbose=False) + + parser.add_argument( + "-d", + "--use_dynamo_export", + action="store_true", + help="Use the new Dynamo exporter instead of the old TorchScript exporter", + ) + parser.set_defaults(use_dynamo_export=False) + + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./model_cache", + help="model cache dir to override default HF cache dir to avoid overflood the /home dir", + ) + + args = parser.parse_args() + return args + + +def main(): + if version.parse(torch.__version__) < version.parse("2.2.0") and "2.2.0.dev" not in torch.__version__: + # Second predicate is for comparing nightly (ex: 2.2.0.dev20230920 vs 2.2.0) since first predicate is false + # in that scenario. It can be removed when torch v2.2.0 is released in stable. + logger.error(f"Detected PyTorch version {torch.__version__}. Please upgrade and use v2.2.0 or newer.") + return + + args = get_args() + setup_logger(args.verbose) + prepare_environment(args.input, args.output, args.execution_provider != "cpu") + if args.reexport: + remove_existing_files(args.output) + logger.info(f"Arguments: {args}") + + world_size = get_size() + rank = get_rank() + + # Load model and config + use_auth_token = args.input == os.path.join(".") + setattr(args, "use_auth_token", use_auth_token) # noqa: B010 + + original_model_name = args.model_name + setattr(args, "original_model_name", original_model_name) # noqa: B010 + args.model_name = args.model_name.split("/")[-1] + + setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010 + setattr(args, "device", torch.device(args.device_name)) # noqa: B010 + + location = args.original_model_name if use_auth_token else args.input + + # Use CUDA for LLaMA-2-70B to speed up export and CPU for other models + l_config, llama = setup_torch_model( + args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None + ) + + assert l_config.num_attention_heads % world_size == 0 and l_config.num_key_value_heads % world_size == 0 + + barrier() + for i in range(world_size): + if i == rank: + # Set model paths for FP32 model + decoder_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx" + ) + decoder_with_past_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx" + ) + decoder_merged_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx" + ) + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + missing_separate_exports = ( + args.no_merged + and not os.path.exists(decoder_model_fp32_path) + and not os.path.exists(decoder_with_past_model_fp32_path) + ) + missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path) + + # Export to ONNX + if missing_separate_exports or missing_merged_export: + if args.use_dynamo_export and missing_separate_exports: + logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") + logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") + logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") + logger.warning( + "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" + ) + logger.warning( + "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." + ) + run_dynamo_export(args, l_config, llama) + elif args.no_merged: + run_torchscript_separate_export(args, l_config, llama, rank, world_size) + else: + run_torchscript_merged_export(args, l_config, llama, rank, world_size) + del llama # Delete LLaMA model from memory since it will be loaded again during parity check + + # Set model paths to store FP32 optimized model + decoder_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32_opt.onnx" + ) + decoder_with_past_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32_opt.onnx" + ) + decoder_merged_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32_opt.onnx" + ) + new_paths = [ + decoder_model_fp32_opt_path, + decoder_with_past_model_fp32_opt_path, + decoder_merged_model_fp32_opt_path, + ] + + # Run the optimizer script + logger.info("Optimizing models...") + for orig_path, opt_path in zip(old_paths, new_paths): + if os.path.exists(orig_path): + optimize_export(l_config, input_path=orig_path, output_path=opt_path) + + # Re-assign default FP32 model paths as their optimized versions + decoder_model_fp32_path = decoder_model_fp32_opt_path + decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path + decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + logger.info( + f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" + ) + + # Change precision of exported models from FP32 + if args.precision == Precision.FLOAT16: + new_paths = convert_to_float16(args, l_config, old_paths, rank, world_size) + + elif args.precision == Precision.INT8: + decoder_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx" + ) + decoder_with_past_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx" + ) + decoder_merged_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx" + ) + new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] + + if args.quantization_method == "smooth_quant": + if not args.no_merged: + logger.error("SmoothQuant must be used on separately exported models") + else: + logger.info( + f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8" + ) + smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) + + elif args.quantization_method == "quantize_dynamic": + logger.warning( + "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." + ) + + logger.info("Quantizing to int8...") + for fp32_path, int8_path in zip(old_paths, new_paths): + if os.path.exists(fp32_path): + ort_quantization.quantize_dynamic( + fp32_path, + int8_path, + op_types_to_quantize=["MatMul", "Gemm", "Gather"] + if args.quantize_embedding_layer + else ["MatMul", "Gemm"], + per_channel=args.quantize_per_channel, + reduce_range=args.quantize_reduce_range, + use_external_data_format=True, + extra_options={"MatMulConstBOnly": True}, + ) + logger.info( + f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!" + ) + remove_existing_model(decoder_model_fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + + else: + raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") + + elif args.precision == Precision.INT4: + if args.execution_provider != "cpu": + old_paths = convert_to_float16(args, l_config, old_paths, rank, world_size) + + decoder_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx" + ) + decoder_with_past_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx" + ) + decoder_merged_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx" + ) + new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] + + for fp_path, int4_path in zip(old_paths, new_paths): + if os.path.exists(fp_path): + model = onnx.load_model(fp_path, load_external_data=True) + quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[]) + quant.process() + quant.model.save_model_to_file(int4_path, use_external_data_format=True) + del model + del quant + logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") + remove_existing_model(fp_path) + barrier() + + logger.info("Verifying parity on all ONNX models created") + + # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models + args.precision = ( + "fp32" + if args.precision in {Precision.INT8, Precision.FLOAT32} + or (args.precision == Precision.INT4 and args.execution_provider == "cpu") + else "fp16" + ) + + # Verify parity on all saved ONNX models + for filename in os.listdir(args.output): + if ( + ".data" in filename + or ".onnx" not in filename + or args.precision not in filename + or f"rank_{rank}" not in filename + ): + continue + + parity_cmd = [ + "-m", + original_model_name, + "-o", + os.path.join(args.output, filename), + "-ep", + args.execution_provider, + "-fp", + args.precision, + "--cache_dir", + args.cache_dir, + ] + if "with_past" in filename: + parity_cmd.append("--use_past_kv") + if "merged" in filename: + parity_cmd.append("--merged") + if args.use_gqa: + parity_cmd.append("--use_gqa") + + try: + logger.debug(f"check parity with cmd: {parity_cmd}") + parity_check(parity_cmd) + except Exception as e: + logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py new file mode 100644 index 0000000000..50b0669d6d --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py @@ -0,0 +1,45 @@ +import os + +import torch.distributed as dist + +comm = None + + +def init_dist(): + if "LOCAL_RANK" in os.environ: + int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank) + elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: + from mpi4py import MPI + + comm = MPI.COMM_WORLD # noqa: F841 + + int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)) + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) + + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank) + else: + # don't need to do init for single process + pass + + +def get_rank(): + return comm.Get_rank() if comm is not None else 0 + + +def get_size(): + return comm.Get_size() if comm is not None else 1 + + +def barrier(): + if comm is not None: + comm.Barrier() + + +def print_out(*args): + if get_rank() == 0: + print(*args) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py new file mode 100644 index 0000000000..bae1ae82e8 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -0,0 +1,312 @@ +from typing import List, Tuple + +import numpy as np +import torch +from transformers import AutoConfig + +from onnxruntime import InferenceSession, OrtValue + + +# Get position_ids from attention_mask +def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if use_past_kv: + # Shape: (batch_size, 1) + position_ids = position_ids[:, -1].unsqueeze(-1) + + # Shape: (batch_size, sequence_length) + return position_ids + + +# Inputs for first pass to get initial past_key_values +# input_ids: (batch_size, sequence_length) +# attention_mask: (batch_size, sequence_length) +# position_ids: (batch_size, sequence_length) +def get_sample_inputs( + config: AutoConfig, + device: torch.device, + batch_size: int, + seq_len: int, + engine: str = "pt", + return_dict: bool = False, +): + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64) + position_ids = get_position_ids(attention_mask, use_past_kv=False) + + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + + if not return_dict: + # For export + return (input_ids, attention_mask, position_ids) + + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + return inputs + + +# Inputs for subsequent passes with past_key_values +# input_ids: (batch_size, 1) +# attention_mask: (batch_size, past_sequence_length + 1) +# position_ids: (batch_size, 1) +# past_key: (batch_size, num_heads, past_sequence_length, head_size) +# past_value: (batch_size, num_heads, past_sequence_length, head_size) +def get_sample_with_past_kv_inputs( + config: AutoConfig, + device: torch.device, + batch_size: int, + past_seq_len: int, + use_fp16: bool = False, + engine: str = "pt", + return_dict: bool = False, + world_size: int = 1, +): + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64) + attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64) + # position_ids is of shape (batch_size, 1) + position_ids = get_position_ids(attention_mask, use_past_kv=True) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size) + + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + past_kv = ( + flatten_past_kv_inputs(past_kv) + if engine == "ort" + else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv)) + ) + + if not return_dict: + # For export + assert isinstance(past_kv, list) + return (input_ids, attention_mask, position_ids, past_kv) + + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + if engine == "ort": + assert isinstance(past_kv, dict) + inputs.update(past_kv) + else: + assert isinstance(past_kv, list) + inputs["past_key_values"] = past_kv + + return inputs + + +# Inputs for all passes with past_key_values +# input_ids: (batch_size, sequence_length) +# attention_mask: (batch_size, past_sequence_length + sequence_length) +# position_ids: (batch_size, sequence_length) +# past_key: (batch_size, num_heads, kv_sequence_length, head_size) +# For models with GQA, kv_sequence_length = max_sequence_length +# For models without GQA, kv_sequence_length = past_sequence_length +# past_value: (batch_size, num_heads, kv_sequence_length, head_size) +# For models with GQA, kv_sequence_length = max_sequence_length +# For models without GQA, kv_sequence_length = past_sequence_length +def get_merged_sample_with_past_kv_inputs( + config: AutoConfig, + device: torch.device, + batch_size: int, + seq_len: int, + past_seq_len: int, + max_seq_len: int, + use_fp16: bool = False, + use_gqa: bool = False, + engine: str = "pt", + return_dict: bool = False, + world_size: int = 1, +): + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) + attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64) + # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation + position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size) + + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + past_kv = ( + flatten_past_kv_inputs(past_kv) + if engine == "ort" + else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv)) + ) + + if not return_dict: + # For export + assert isinstance(past_kv, list) + return (input_ids, attention_mask, position_ids, past_kv) + + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + if engine == "ort": + assert isinstance(past_kv, dict) + inputs.update(past_kv) + + if use_gqa: + inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len) + + else: + assert isinstance(past_kv, list) + inputs["past_key_values"] = past_kv + + return inputs + + +# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx +def get_msft_sample_inputs( + config: AutoConfig, + batch_size: int, + past_seq_len: int, + seq_len: int, + max_seq_len: int, + use_fp16: bool, + use_gqa: bool, + split_kv: bool, +): + np_dtype = np.float16 if use_fp16 else np.float32 + head_size = config.hidden_size // config.num_attention_heads + + if not split_kv: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), + "k_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "v_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "pos": np.array(past_seq_len, dtype=np.int64), + } + else: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype( + np.int32 + ), + "pos": np.array(past_seq_len, dtype=np.int64), + } + for i in range(config.num_hidden_layers): + ort_inputs.update( + { + f"k_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + f"v_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + } + ) + + if use_gqa: + ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) + + return ort_inputs + + +# Create past_key_values +# Each is of shape (batch_size, num_heads, past_sequence_length, head_size) +def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1): + num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads + torch_dtype = torch.float16 if use_fp16 else torch.float32 + past_kv = [ + ( + torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype), + torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype), + ) + for _ in range(config.num_hidden_layers) + ] + return past_kv + + +# Convert list of past_key_values to dict of past_key and past_value +def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]): + past_kv = {} + for i, (past_k, past_v) in enumerate(past_key_values): + past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy() + past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy() + return past_kv + + +# Format PyTorch inputs to ONNX Runtime inputs +def convert_inputs_for_ort( + pt_inputs: dict, + use_gqa: bool = False, + past_seq_len: int = 0, + max_seq_len: int = 2048, + device: str = "", + device_id: int = -1, +): + ort_inputs = {} + for k, v in pt_inputs.items(): + if isinstance(v, np.ndarray): + ort_inputs[k] = v + elif k == "past_key_values": + ort_inputs.update(flatten_past_kv_inputs(v)) + else: + ort_inputs[k] = v.detach().cpu().numpy() + + # Reshape KV caches if using past-present-share-buffer + if use_gqa and device != "" and device != "cpu" and device_id > -1: + ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) + + return ort_inputs + + +def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int): + for k, v in ort_inputs.items(): + # Allocate new buffers with max_sequence_length for GQA + if "cache" in k or "past_key_values" in k: + # Copy v (BxSxPxH) into new_v (BxSxMxH) + batch_size, num_heads, _, head_size = v.shape + new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype) + new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v + ort_inputs[k] = new_v + return ort_inputs + + +# Add IO bindings for execution providers +def add_io_bindings( + model: InferenceSession, ort_inputs: dict, device: str, device_id: int, use_gqa: bool, kv_cache_ortvalues: dict +): + io_binding = model.io_binding() + + for k, v in ort_inputs.items(): + # Bind OrtValue inputs to device + if use_gqa and ("cache" in k or "past_key_values" in k): + if k not in kv_cache_ortvalues: + v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id) + io_binding.bind_ortvalue_input(k, v_device) + kv_cache_ortvalues[k] = v_device + else: + kv_cache_ortvalues[k].update_inplace(v) + io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k]) + else: + v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id) + io_binding.bind_ortvalue_input(k, v_device) + + for output in model.get_outputs(): + name = output.name + if use_gqa and ("out" in name or "present" in name): + # Bind present KV cache outputs to past KV cache inputs in order to buffer share + input_name = name.replace("out", "cache").replace("present", "past_key_values") + io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name]) + else: + io_binding.bind_output(name, device_type=device, device_id=device_id) + + return io_binding, kv_cache_ortvalues diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py new file mode 100644 index 0000000000..418a65325c --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -0,0 +1,274 @@ +import argparse +import logging +import os +import time +from typing import List + +import numpy as np +import torch +from benchmark_helper import setup_logger +from dist_settings import get_rank, get_size +from llama_inputs import ( + add_io_bindings, + convert_inputs_for_ort, + get_merged_sample_with_past_kv_inputs, + get_sample_inputs, + get_sample_with_past_kv_inputs, +) +from llama_torch import setup_torch_model +from transformers import AutoConfig, AutoModelForCausalLM + +import onnxruntime as ort + +logger = logging.getLogger("") + + +def get_sequence_lengths(args: argparse.Namespace): + past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8) + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 + return past_sequence_length, curr_sequence_length, max_sequence_length + + +def get_inputs(args: argparse.Namespace, config: AutoConfig): + # Dummy values for parity + world_size = get_size() + batch_size = 2 + past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args) + + if args.merged: + inputs = get_merged_sample_with_past_kv_inputs( + config, + args.device, + batch_size, + seq_len=sequence_length, + past_seq_len=past_sequence_length, + max_seq_len=max_sequence_length, + use_fp16=args.use_fp16, + use_gqa=args.use_gqa, + return_dict=True, + world_size=world_size, + ) + elif args.use_past_kv: + inputs = get_sample_with_past_kv_inputs( + config, + args.device, + batch_size, + sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + world_size=world_size, + ) + else: + inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True) + + return inputs + + +def verify_parity( + args: argparse.Namespace, config: AutoConfig, pt_model: AutoModelForCausalLM, kv_cache_ortvalues: dict +): + inputs = get_inputs(args, config) + + # Run inference with PyTorch + if args.execution_provider != "cpu": + torch.cuda.synchronize() + start_time = time.time() + pt_outputs = pt_model(**inputs).logits.detach().cpu().numpy() + if args.execution_provider != "cpu": + torch.cuda.synchronize() + end_time = time.time() + logger.info(f"PyTorch took {end_time - start_time} s") + del pt_model + + # Run inference with ORT + past_sequence_length, _, max_sequence_length = get_sequence_lengths(args) + inputs = convert_inputs_for_ort( + inputs, + use_gqa=args.use_gqa, + past_seq_len=past_sequence_length, + max_seq_len=max_sequence_length, + device=args.execution_provider, + device_id=int(args.rank), + ) + + ep = f"{args.execution_provider.upper()}ExecutionProvider" + if ep == "CUDAExecutionProvider": + ep = (ep, {"device_id": args.rank}) + ort_model = ort.InferenceSession( + args.onnx_model_path, + sess_options=ort.SessionOptions(), + providers=[ep], + ) + + # Add IO bindings for non-CPU execution providers + if args.execution_provider != "cpu": + io_binding, kv_cache_ortvalues = add_io_bindings( + ort_model, + inputs, + args.execution_provider, + int(args.rank), + args.use_gqa, + kv_cache_ortvalues, + ) + + io_binding.synchronize_inputs() + start_time = time.time() + ort_model.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + end_time = time.time() + + ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits + del ort_model + + else: + start_time = time.time() + ort_outputs = ort_model.run(None, inputs) + end_time = time.time() + + ort_outputs = ort_outputs[0] # Get logits + + logger.info(f"ONNX Runtime took {end_time - start_time} s") + + # Compare PyTorch and ONNX Runtime accuracy + tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1 + parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol) + logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}") + if not parity: + logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}") + return kv_cache_ortvalues + + +def get_args(argv: List[str]): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--model_name", + required=True, + help="Model name in Hugging Face", + ) + + parser.add_argument( + "-t", + "--torch_model_directory", + required=False, + default=os.path.join("."), + help="Path to folder containing PyTorch model and associated files if saved on disk", + ) + + parser.add_argument( + "-o", + "--onnx_model_path", + required=True, + default=os.path.join("."), + help="Path to ONNX model (with external data files saved in the same folder as the model)", + ) + + parser.add_argument( + "-ep", + "--execution_provider", + required=False, + default="cpu", + choices=["cpu", "cuda", "rocm"], + help="Execution provider to verify parity with", + ) + + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Print verbose logs", + ) + parser.set_defaults(verbose=False) + + parser.add_argument( + "-p", + "--use_past_kv", + action="store_true", + help="Use past key and past value as inputs to the model. Necessary for decoder_with_past_model.onnx models.", + ) + parser.set_defaults(use_past_kv=False) + + parser.add_argument( + "-g", + "--use_gqa", + action="store_true", + help="Use if model has GroupQueryAttention", + ) + parser.set_defaults(use_gqa=False) + + parser.add_argument( + "--merged", + action="store_true", + help="Use merged model (i.e. decoder_merged_model.onnx).", + ) + parser.set_defaults(merged=False) + + parser.add_argument( + "-fp", + "--precision", + required=True, + choices=["int4", "int8", "fp16", "fp32"], + help="Precision of model", + ) + + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./model_cache", + help="model cache dir to override default HF cache dir to avoid overflood the /home dir", + ) + + args = parser.parse_args() if argv == [] else parser.parse_args(argv) + + # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models + args.precision = ( + "fp32" + if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.execution_provider == "cpu") + else "fp16" + ) + return args + + +def main(argv: List[str] = []): # noqa: B006 + args = get_args(argv) + setup_logger(args.verbose) + logger.info(f"Arguments: {args}") + rank = get_rank() + + # Load model and config + setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010 + args.rank = rank + setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010 + setattr(args, "device", torch.device(args.device_name)) # noqa: B010 + use_auth_token = args.torch_model_directory == os.path.join(".") + location = args.model_name if use_auth_token else args.torch_model_directory + + config, llama = setup_torch_model( + args, + location, + use_auth_token, + torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), + device=args.device, + ) + + kv_cache_ortvalues = {} + if not args.merged: + verify_parity(args, config, llama, kv_cache_ortvalues) + else: + # Verify prompt generation in merged model (decoder_model.onnx) + args.use_past_kv = False + kv_cache_ortvalues = verify_parity(args, config, llama, kv_cache_ortvalues) + + # Verify token generation in merged model (decoder_with_past_model.onnx) + args.use_past_kv = True + verify_parity(args, config, llama, kv_cache_ortvalues) + + +if __name__ == "__main__": + seed = 2 + np.random.seed(seed) + torch.manual_seed(seed) + main() diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py new file mode 100644 index 0000000000..94e0397116 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py @@ -0,0 +1,38 @@ +import logging +import os + +import torch +from dist_settings import barrier, get_rank, get_size +from transformers import AutoConfig, AutoModelForCausalLM + +logger = logging.getLogger("") + + +def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, device=None): + world_size = get_size() + logger.info(f"world_size: {world_size}") + rank = get_rank() + barrier() + + if not os.path.exists(args.cache_dir): + os.makedirs(args.cache_dir, exist_ok=True) + + for i in range(world_size): + if i == rank % (world_size): + l_config = AutoConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir) + l_config.use_cache = True + llama = AutoModelForCausalLM.from_pretrained( + location, + use_auth_token=use_auth_token, + config=l_config, + torch_dtype=torch_dtype, + cache_dir=args.cache_dir, + ) + if world_size > 1: + llama.parallel_model() + if device: + llama.to(device) + llama.eval() + llama.requires_grad_(False) + barrier() + return l_config, llama diff --git a/onnxruntime/python/tools/transformers/models/llama/quant_kv_dataloader.py b/onnxruntime/python/tools/transformers/models/llama/quant_kv_dataloader.py new file mode 100644 index 0000000000..e8b5632610 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/quant_kv_dataloader.py @@ -0,0 +1,103 @@ +import argparse + +import numpy as np +import torch +from benchmark_helper import create_onnxruntime_session +from datasets import load_dataset +from llama_inputs import get_position_ids +from torch.nn.functional import pad +from torch.utils.data import DataLoader +from transformers import LlamaTokenizer + + +class QuantKVDataLoader: + def __init__(self, args: argparse.Namespace, onnx_model_path: str = ""): + self.batch_size = 1 + self.pad_max = args.pad_max + + tokenizer = LlamaTokenizer.from_pretrained(args.original_model_name, use_auth_token=args.use_auth_token) + dataset = load_dataset(args.smooth_quant_dataset, split="train") + dataset = dataset.map(lambda examples: tokenizer(examples["text"]), batched=True) + dataset.set_format(type="torch", columns=["input_ids", "attention_mask"]) + + self.dataloader = DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=False, + collate_fn=self.collate_batch, + ) + self.decoder_model = ( + create_onnxruntime_session( + onnx_model_path, + args.execution_provider != "cpu", # use_gpu + provider=args.execution_provider, + verbose=args.verbose, + ) + if onnx_model_path + else None + ) + + def collate_batch(self, batch): + input_ids_batched = [] + attention_mask_batched = [] + position_ids_batched = [] + labels = [] + + for text in batch: + # Set inputs for model + input_ids = text["input_ids"] + attention_mask = torch.ones(len(input_ids)) + position_ids = get_position_ids(attention_mask, use_past_kv=False) + label = len(input_ids) - 1 + + # Pad input data because all model inputs must have same shape + pad_len = self.pad_max - input_ids.shape[0] + input_ids = pad(input_ids, (0, pad_len), value=1) + attention_mask = pad(attention_mask, (0, pad_len), value=0) + position_ids = pad(position_ids, (0, pad_len), value=0) + + input_ids_batched.append(input_ids) + attention_mask_batched.append(attention_mask) + position_ids_batched.append(position_ids) + labels.append(label) + + input_ids_batched = torch.vstack(input_ids_batched) + attention_mask_batched = torch.vstack(attention_mask_batched) + position_ids_batched = torch.vstack(position_ids_batched) + labels = torch.tensor(labels) + + return (input_ids_batched, attention_mask_batched, position_ids_batched), labels + + def __iter__(self): + try: + for (input_ids, attention_mask, position_ids), labels in self.dataloader: + # Inputs for decoder_model.onnx + inputs = { + "input_ids": input_ids[:, :-1].detach().cpu().numpy().astype(np.int64), + "attention_mask": attention_mask[:, :-1].detach().cpu().numpy().astype(np.int64), + "position_ids": position_ids[:, :-1].detach().cpu().numpy().astype(np.int64), + } + label = labels.detach().cpu().numpy() + + if self.decoder_model is not None: + # Run decoder_model.onnx to get inputs for decoder_with_past_model.onnx + outputs = self.decoder_model.run(None, inputs) + + for i in range(int((len(outputs) - 1) / 2)): + inputs[f"past_key_values.{i}.key"] = outputs[i * 2 + 1] + inputs[f"past_key_values.{i}.value"] = outputs[i * 2 + 2] + past_sequence_length = inputs["past_key_values.0.key"].shape[2] + + inputs["input_ids"] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype(np.int64) + attn_mask_torch = torch.ones((self.batch_size, past_sequence_length + 1), dtype=torch.int64) + inputs["attention_mask"] = attn_mask_torch.detach().cpu().numpy().astype(np.int64) + inputs["position_ids"] = ( + get_position_ids(attn_mask_torch, use_past_kv=True).detach().cpu().numpy().astype(np.int64) + ) + + # Yield (inputs, label) tuple for Intel's Neural Compressor: + # https://github.com/intel/neural-compressor/blob/d4baed9ea11614e1f0dc8a1f4f55b73ed3ed585c/neural_compressor/quantization.py#L55-L62 + yield (inputs, label) + + except StopIteration: + return diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt new file mode 100644 index 0000000000..572cfdb71b --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt @@ -0,0 +1,4 @@ +-r requirements.txt +git+https://github.com/frankdongms/transformers.git@frdong/shard_llama +mpi4py +psutil \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt new file mode 100644 index 0000000000..3d707fa13e --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt @@ -0,0 +1,2 @@ +-r requirements.txt +onnxruntime>=1.16.2 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt new file mode 100644 index 0000000000..b634bcc50f --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt @@ -0,0 +1,4 @@ +-r requirements.txt +# Please manually install torch>=2.2.0.dev20230920 with CUDA enabled for the CUDA version installed in your system. +# Instructions can be found here: https://pytorch.org/get-started/locally/ +onnxruntime-gpu>=1.16.2 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-quant.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-quant.txt new file mode 100644 index 0000000000..890e636428 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-quant.txt @@ -0,0 +1,2 @@ +-r requirements-cpu.txt +neural-compressor>=2.2.1 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt new file mode 100644 index 0000000000..4210f36982 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -0,0 +1,6 @@ +git+https://github.com/huggingface/optimum.git +transformers>=4.33.2 +torch>=2.2.0.dev20230920 +onnx>=1.14.0 +datasets>=2.8.0 +protobuf==3.20.2 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index d184224317..d937e3f421 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -72,37 +72,52 @@ cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion Below is an example to optimize Stable Diffusion 1.5 in Linux. For Windows OS, please change the format of path to be like `.\sd` instead of `./sd`. +It is recommended to create a Conda environment with Python 3.10 for the following setup: +``` +conda create -n py310 python=3.10 +conda activate py310 +``` + ### Setup Environment (CUDA) -It is recommended to create a Conda environment with Python 3.8, 3.9 or 3.10, and run the model with [CUDA 11.7](https://developer.nvidia.com/cuda-11-7-0-download-archive) or 11.8. +First, we need install CUDA 11.8 or 12.1, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html) 8.5 or above, and [TensorRT 8.6.1](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) in the machine. + +#### CUDA 11.8: + +In the Conda environment, install PyTorch 2.1 or above, and other required packages like the following: ``` -conda create -n py38 python=3.8 -conda activate py38 -pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 -pip install -r requirements-cuda.txt +pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118 +pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +pip install -r requirements-cuda11.txt ``` -ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.7 and cuDNN 8.5 are used in our tests. +We cannot directly `pip install tensorrt` for CUDA 11. Follow https://github.com/NVIDIA/TensorRT/issues/2773 to install TensorRT for CUDA 11 in Linux. For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. -#### Install Nightly (Optional) +#### CUDA 12.*: +The official package of onnxruntime-gpu 1.16.* is built for CUDA 11.8. To use CUDA 12.*, you will need [build onnxruntime from source](https://onnxruntime.ai/docs/build/inferencing.html). -Skip this step if you use onnxruntime-gpu package from official releases. - -To try latest optimizations, you can install [ort-nightly-gpu](https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/ort-nightly-gpu/) package like the following: +``` +git clone --recursive https://github.com/Microsoft/onnxruntime.git +cd onnxruntime +pip install -r requirements-dev.txt +``` +Follow [example script for A100 in Ubuntu](https://github.com/microsoft/onnxruntime/blob/26a7b63716e3125bfe35fe3663ba10d2d7322628/build_release.sh) +or [example script for RTX 4090 in Windows](https://github.com/microsoft/onnxruntime/blob/8df5f4e0df1f3b9ceeb0f1f2561b09727ace9b37/build_trt.cmd) to build and install onnxruntime-gpu wheel. +Then install other python packages like the following: ``` -pip uninstall onnxruntime-gpu -pip install ort-nightly-gpu -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ +pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121 +pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +pip install -r requirements-cuda12.txt ``` +Finally, `pip install tensorrt` for Linux. For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. ### Setup Environment (ROCm) -It is recommended that the users run the model with ROCm 5.4 or newer and Python 3.8, 3.9 or 3.10. +It is recommended that the users run the model with ROCm 5.4 or newer and Python 3.10. Note that Windows is not supported for ROCm at the moment. ``` -conda create -n py38 python=3.8 -conda activate py38 wget https://repo.radeon.com/rocm/manylinux/rocm-rel-5.4/torch-1.12.1%2Brocm5.4-cp38-cp38-linux_x86_64.whl pip install torch-1.12.1+rocm5.4-cp38-cp38-linux_x86_64.whl pip install -r requirements-rocm.txt @@ -154,6 +169,12 @@ curl https://raw.githubusercontent.com/huggingface/diffusers/v0.15.1/scripts/con python convert_sd_onnx.py --model_path runwayml/stable-diffusion-v1-5 --output_path ./sd_v1_5/fp32 ``` +For SDXL, use optimum to export the model: +``` +pip install optimum diffusers onnx onnxruntime-gpu +optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl ./sd_xl_base_onnx +``` + ### Optimize ONNX Pipeline Example to optimize the exported float32 ONNX models, and save to float16 models: @@ -161,7 +182,10 @@ Example to optimize the exported float32 ONNX models, and save to float16 models python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ./sd_v1_5/fp32 -o ./sd_v1_5/fp16 --float16 ``` -If you installed ONNX Runtime v1.14, some optimizations (packed QKV and BiasAdd) will be disabled automatically since they are not available in v1.14. +For SDXL model, it is recommended to use a machine with 32 GB or more memory to optimize. +``` +python optimize_pipeline.py -i ./sd_xl_base_onnx -o ./sd_xl_base_fp16 --float16 +``` ### Run Benchmark @@ -224,18 +248,21 @@ Sometime, it complains ptxas not found when there are multiple CUDA versions ins Note that torch.compile is not supported in Windows: we encountered error `Windows not yet supported for torch.compile`. So it is excluded from RTX 3060 results of Windows. -### Run Benchmark with TensorRT and TensorRT execution provider +### Run Benchmark with TensorRT or TensorRT execution provider For TensorRT installation, follow https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html. ``` pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 -pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com pip install -r requirements-tensorrt.txt export CUDA_MODULE_LOADING=LAZY python benchmark.py -e tensorrt -b 1 -v 1.5 python benchmark.py -e onnxruntime -r tensorrt -b 1 -v 1.5 python benchmark.py -e onnxruntime -r tensorrt -b 1 -v 1.5 --enable_cuda_graph + +python benchmark.py -e tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 --enable_cuda_graph +python benchmark.py -e onnxruntime -r tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 --enable_cuda_graph ``` ### Example Benchmark output @@ -384,8 +411,7 @@ Some kernels are enabled by MIOpen. We hereby thank for the AMD developers' coll There are other optimizations might improve the performance or reduce memory footprint: * Export the whole pipeline into a single ONNX model. Currently, there are multiple ONNX models (CLIP, VAE and U-Net etc). Each model uses separated thread pool and memory allocator. Combine them into one model could share thread pool and memory allocator. The end result is more efficient and less memory footprint. -* For Stable Diffusion 2.1, we disable TensorRT flash attention kernel and use only memory efficient attention. It is possible to add flash attention using Triton compiler to improve performance. +* For Stable Diffusion 2.1, we disable TensorRT flash attention kernel and use only memory efficient attention. It is possible to add flash attention in Windows to improve performance. * Reduce GPU memory footprint by actively deleting buffers for intermediate results. -* Attention fusion in CLIP * Safety Checker Optimization * Leverage FP8 in latest GPU diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 13126f648d..1f1db914e2 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -10,15 +10,18 @@ import sys import time +import __init__ # noqa: F401. Walk-around to run this script directly import coloredlogs # import torch before onnxruntime so that onnxruntime uses the cuDNN in the torch package. import torch +from benchmark_helper import measure_memory SD_MODELS = { "1.5": "runwayml/stable-diffusion-v1-5", "2.0": "stabilityai/stable-diffusion-2", "2.1": "stabilityai/stable-diffusion-2-1", + "xl-1.0": "stabilityai/stable-diffusion-xl-refiner-1.0", } PROVIDERS = { @@ -43,139 +46,13 @@ def example_prompts(): "delicate elvish moonstone necklace on a velvet background, symmetrical intricate motifs, leaves, flowers, 8k", ] - return prompts + negative_prompt = "bad composition, ugly, abnormal, malformed" - -class CudaMemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring - - def measure_gpu_usage(self): - from py3nvml.py3nvml import ( - NVMLError, - nvmlDeviceGetCount, - nvmlDeviceGetHandleByIndex, - nvmlDeviceGetMemoryInfo, - nvmlDeviceGetName, - nvmlInit, - nvmlShutdown, - ) - - max_gpu_usage = [] - gpu_name = [] - try: - nvmlInit() - device_count = nvmlDeviceGetCount() - if not isinstance(device_count, int): - print(f"nvmlDeviceGetCount result is not integer: {device_count}") - return None - - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)] - while True: - for i in range(device_count): - info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)) - if isinstance(info, str): - print(f"nvmlDeviceGetMemoryInfo returns str: {info}") - return None - max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2) - time.sleep(0.002) # 2ms - if not self.keep_measuring: - break - nvmlShutdown() - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] - except NVMLError as error: - print("Error fetching GPU information using nvml: %s", error) - return None - - -class RocmMemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring - rocm_smi_path = "/opt/rocm/libexec/rocm_smi" - if os.path.exists(rocm_smi_path): - if rocm_smi_path not in sys.path: - sys.path.append(rocm_smi_path) - try: - import rocm_smi - - self.rocm_smi = rocm_smi - self.rocm_smi.initializeRsmi() - except ImportError: - self.rocm_smi = None - - def get_used_memory(self, dev): - if self.rocm_smi is None: - return -1 - return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024 - - def measure_gpu_usage(self): - device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0 - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [f"GPU{i}" for i in range(device_count)] - while True: - for i in range(device_count): - max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i)) - time.sleep(0.002) # 2ms - if not self.keep_measuring: - break - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] + return prompts, negative_prompt def measure_gpu_memory(monitor_type, func, start_memory=None): - if monitor_type is None: - return None - - monitor = monitor_type(False) - memory_before_test = monitor.measure_gpu_usage() - - if start_memory is None: - start_memory = memory_before_test - if start_memory is None: - return None - if func is None: - return start_memory - - from concurrent.futures import ThreadPoolExecutor - - with ThreadPoolExecutor() as executor: - monitor = monitor_type() - mem_thread = executor.submit(monitor.measure_gpu_usage) - try: - fn_thread = executor.submit(func) - _ = fn_thread.result() - finally: - monitor.keep_measuring = False - max_usage = mem_thread.result() - - if max_usage is None: - return None - - print(f"GPU memory usage: before={memory_before_test} peak={max_usage}") - if len(start_memory) >= 1 and len(max_usage) >= 1 and len(start_memory) == len(max_usage): - # When there are multiple GPUs, we will check the one with maximum usage. - max_used = 0 - for i, memory_before in enumerate(start_memory): - before = memory_before["max_used_MB"] - after = max_usage[i]["max_used_MB"] - used = after - before - max_used = max(max_used, used) - return max_used - return None + return measure_memory(is_gpu=True, func=func, monitor_type=monitor_type, start_memory=start_memory) def get_ort_pipeline(model_name: str, directory: str, provider, disable_safety_checker: bool): @@ -256,7 +133,7 @@ def run_ort_pipeline( assert isinstance(pipe, OnnxStableDiffusionPipeline) - prompts = example_prompts() + prompts, negative_prompt = example_prompts() def warmup(): pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) @@ -275,13 +152,12 @@ def warmup(): for j in range(batch_count): inference_start = time.time() images = pipe( - prompt, + [prompt] * batch_size, height, width, num_inference_steps=steps, - negative_prompt=None, + negative_prompt=[negative_prompt] * batch_size, guidance_scale=7.5, - num_images_per_prompt=batch_size, ).images inference_end = time.time() latency = inference_end - inference_start @@ -320,7 +196,7 @@ def run_torch_pipeline( start_memory, memory_monitor_type, ): - prompts = example_prompts() + prompts, negative_prompt = example_prompts() # total 2 runs of warm up, and measure GPU memory for CUDA EP def warmup(): @@ -342,13 +218,12 @@ def warmup(): for j in range(batch_count): inference_start = time.time() images = pipe( - prompt=prompt, + prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, guidance_scale=7.5, - negative_prompt=None, - num_images_per_prompt=batch_size, + negative_prompt=[negative_prompt] * batch_size, generator=None, # torch.Generator ).images @@ -427,7 +302,7 @@ def run_ort( def export_and_run_ort( - model_name: str, + version: str, provider: str, batch_size: int, disable_safety_checker: bool, @@ -443,15 +318,19 @@ def export_and_run_ort( assert provider == "CUDAExecutionProvider" from diffusers import DDIMScheduler + from diffusion_models import PipelineInfo from onnxruntime_cuda_txt2img import OnnxruntimeCudaStableDiffusionPipeline - scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") + pipeline_info = PipelineInfo(version) + model_name = pipeline_info.name() + scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") pipe = OnnxruntimeCudaStableDiffusionPipeline.from_pretrained( model_name, scheduler=scheduler, requires_safety_checker=not disable_safety_checker, enable_cuda_graph=enable_cuda_graph, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models @@ -473,7 +352,7 @@ def warmup(): image_filename_prefix = get_image_filename_prefix("ort_cuda", model_name, batch_size, disable_safety_checker) latency_list = [] - prompts = example_prompts() + prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break @@ -481,6 +360,7 @@ def warmup(): inference_start = time.time() images = pipe( [prompt] * batch_size, + negative_prompt=[negative_prompt] * batch_size, num_inference_steps=steps, ).images inference_end = time.time() @@ -514,7 +394,7 @@ def warmup(): def run_ort_trt( - model_name: str, + version: str, batch_size: int, disable_safety_checker: bool, height: int, @@ -528,8 +408,12 @@ def run_ort_trt( enable_cuda_graph: bool, ): from diffusers import DDIMScheduler + from diffusion_models import PipelineInfo from onnxruntime_tensorrt_txt2img import OnnxruntimeTensorRTStableDiffusionPipeline + pipeline_info = PipelineInfo(version) + model_name = pipeline_info.name() + assert batch_size <= max_batch_size scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") @@ -544,6 +428,7 @@ def run_ort_trt( max_batch_size=max_batch_size, onnx_opset=17, enable_cuda_graph=enable_cuda_graph, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models and TensorRT Engines @@ -552,7 +437,7 @@ def run_ort_trt( pipe = pipe.to("cuda") def warmup(): - pipe(["warm up"] * batch_size, num_inference_steps=steps) + pipe(["warm up"] * batch_size, negative_prompt=["negative"] * batch_size, num_inference_steps=steps) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -564,7 +449,7 @@ def warmup(): image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) latency_list = [] - prompts = example_prompts() + prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break @@ -572,6 +457,7 @@ def warmup(): inference_start = time.time() images = pipe( [prompt] * batch_size, + negative_prompt=[negative_prompt] * batch_size, num_inference_steps=steps, ).images inference_end = time.time() @@ -589,7 +475,7 @@ def warmup(): "model_name": model_name, "engine": "onnxruntime", "version": ort_version, - "provider": f"tensorrt{trt_version})", + "provider": f"tensorrt({trt_version})", "directory": pipe.engine_dir, "height": height, "width": width, @@ -606,7 +492,148 @@ def warmup(): } -def run_tensorrt( +def run_ort_trt_static( + work_dir: str, + version: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph: bool = True, +): + print("[I] Initializing ORT TensorRT EP accelerated StableDiffusionXL txt2img pipeline (static input shape)") + + # Register TensorRT plugins + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + assert batch_size <= max_batch_size + + from diffusion_models import PipelineInfo + + pipeline_info = PipelineInfo(version) + short_name = pipeline_info.short_name() + + from engine_builder import EngineType, get_engine_paths + from pipeline_txt2img import Txt2ImgPipeline + + engine_type = EngineType.ORT_TRT + onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths(work_dir, pipeline_info, engine_type) + + # Initialize pipeline + pipeline = Txt2ImgPipeline( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=use_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + # Load TensorRT engines and pytorch modules + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + 17, + opt_image_height=height, + opt_image_width=width, + opt_batch_size=batch_size, + force_engine_rebuild=False, + static_batch=True, + static_image_shape=True, + max_workspace_size=0, + device_id=torch.cuda.current_device(), + ) + + # Here we use static batch and image size, so the resource allocation only need done once. + # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. + pipeline.load_resources(height, width, batch_size) + + def warmup(): + pipeline.run( + ["warm up"] * batch_size, ["negative"] * batch_size, height, width, denoising_steps=steps, warmup=True + ) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + image_filename_prefix = get_image_filename_prefix("ort_trt", short_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts, negative_prompt = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = pipeline.run( + [prompt] * batch_size, + [negative_prompt] * batch_size, + height, + width, + denoising_steps=steps, + guidance=7.5, + seed=123, + warmup=True, + ) + images = pipeline.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + + pipeline.teardown() + + from tensorrt import __version__ as trt_version + + from onnxruntime import __version__ as ort_version + + return { + "model_name": pipeline_info.name(), + "engine": "onnxruntime", + "version": ort_version, + "provider": f"tensorrt({trt_version})", + "directory": engine_dir, + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "disable_safety_checker": disable_safety_checker, + "enable_cuda_graph": use_cuda_graph, + } + + +def run_tensorrt_static( + work_dir: str, + version: str, model_name: str, batch_size: int, disable_safety_checker: bool, @@ -618,32 +645,79 @@ def run_tensorrt( start_memory, memory_monitor_type, max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph: bool = True, ): - from diffusers import DDIMScheduler - from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline + print("[I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)") + + from cuda import cudart + + # Register TensorRT plugins + from trt_utilities import init_trt_plugins + + init_trt_plugins() assert batch_size <= max_batch_size - scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") - pipe = StableDiffusionPipeline.from_pretrained( - model_name, - custom_pipeline="stable_diffusion_tensorrt_txt2img", - revision="fp16", - torch_dtype=torch.float16, - scheduler=scheduler, - requires_safety_checker=not disable_safety_checker, - image_height=height, - image_width=width, + from diffusion_models import PipelineInfo + + pipeline_info = PipelineInfo(version) + + from engine_builder import EngineType, get_engine_paths + from pipeline_txt2img import Txt2ImgPipeline + + engine_type = EngineType.TRT + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + work_dir, pipeline_info, engine_type + ) + + # Initialize pipeline + pipeline = Txt2ImgPipeline( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, max_batch_size=max_batch_size, + use_cuda_graph=True, + engine_type=engine_type, ) - # re-use cached folder to save ONNX models and TensorRT Engines - pipe.set_cached_folder(model_name, revision="fp16") + # Load TensorRT engines and pytorch modules + pipeline.backend.load_engines( + engine_dir=engine_dir, + framework_model_dir=framework_model_dir, + onnx_dir=onnx_dir, + onnx_opset=17, + opt_batch_size=batch_size, + opt_image_height=height, + opt_image_width=width, + force_export=False, + force_optimize=False, + force_build=False, + static_batch=True, + static_shape=True, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=timing_cache, + onnx_refit_dir=None, + ) - pipe = pipe.to("cuda") + # activate engines + max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + pipeline.backend.activate_engines(shared_device_memory) + + # Here we use static batch and image size, so the resource allocation only need done once. + # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. + pipeline.load_resources(height, width, batch_size) def warmup(): - pipe(["warm up"] * batch_size, num_inference_steps=steps) + pipeline.run( + ["warm up"] * batch_size, ["negative"] * batch_size, height, width, denoising_steps=steps, warmup=True + ) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -655,28 +729,225 @@ def warmup(): image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) latency_list = [] - prompts = example_prompts() + prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break for j in range(batch_count): inference_start = time.time() - images = pipe( + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = pipeline.run( [prompt] * batch_size, - num_inference_steps=steps, - ).images + [negative_prompt] * batch_size, + height, + width, + denoising_steps=steps, + guidance=7.5, + seed=123, + warmup=True, + ) + images = pipeline.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare inference_end = time.time() latency = inference_end - inference_start latency_list.append(latency) - print(f"Inference took {latency:.3f} seconds") + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") for k, image in enumerate(images): image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") - from tensorrt import __version__ as trt_version + pipeline.teardown() + + import tensorrt as trt + + return { + "engine": "tensorrt", + "version": trt.__version__, + "provider": "default", + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "enable_cuda_graph": use_cuda_graph, + } + + +def run_tensorrt_static_xl( + work_dir: str, + version: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph=True, +): + print("[I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)") + + import tensorrt as trt + from cuda import cudart + from trt_utilities import init_trt_plugins + + # Validate image dimensions + image_height = height + image_width = width + if image_height % 8 != 0 or image_width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}." + ) + + # Register TensorRT plugins + init_trt_plugins() + + assert batch_size <= max_batch_size + + from diffusion_models import PipelineInfo + from engine_builder import EngineType, get_engine_paths + + def init_pipeline(pipeline_class, pipeline_info): + engine_type = EngineType.TRT + + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + work_dir, pipeline_info, engine_type + ) + + # Initialize pipeline + pipeline = pipeline_class( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=use_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + pipeline.backend.load_engines( + engine_dir=engine_dir, + framework_model_dir=framework_model_dir, + onnx_dir=onnx_dir, + onnx_opset=17, + opt_batch_size=batch_size, + opt_image_height=height, + opt_image_width=width, + force_export=False, + force_optimize=False, + force_build=False, + static_batch=True, + static_shape=True, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=timing_cache, + onnx_refit_dir=None, + ) + return pipeline + + from pipeline_img2img_xl import Img2ImgXLPipeline + from pipeline_txt2img_xl import Txt2ImgXLPipeline + + base_pipeline_info = PipelineInfo(version) + demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info) + + refiner_pipeline_info = PipelineInfo(version, is_refiner=True) + demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info) + + max_device_memory = max(demo_base.backend.max_device_memory(), demo_refiner.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + demo_base.backend.activate_engines(shared_device_memory) + demo_refiner.backend.activate_engines(shared_device_memory) + + # Here we use static batch and image size, so the resource allocation only need done once. + # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. + demo_base.load_resources(image_height, image_width, batch_size) + demo_refiner.load_resources(image_height, image_width, batch_size) + + def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False): + images, time_base = demo_base.run( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + return_type="latent", + ) + + images, time_refiner = demo_refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + ) + return images, time_base + time_refiner + + def warmup(): + run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size, warmup=True) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + model_name = refiner_pipeline_info.name() + image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts, negative_prompt = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + if nvtx_profile: + cudart.cudaProfilerStart() + images, pipeline_time = run_sd_xl_inference( + [prompt] * batch_size, [negative_prompt] * batch_size, seed=123, warmup=True + ) + if nvtx_profile: + cudart.cudaProfilerStop() + images = demo_refiner.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.png") + + demo_base.teardown() + demo_refiner.teardown() return { + "model_name": model_name, "engine": "tensorrt", - "version": trt_version, + "version": trt.__version__, "provider": "default", "height": height, "width": width, @@ -688,7 +959,178 @@ def warmup(): "median_latency": statistics.median(latency_list), "first_run_memory_MB": first_run_memory, "second_run_memory_MB": second_run_memory, - "enable_cuda_graph": False, + "enable_cuda_graph": use_cuda_graph, + } + + +def run_ort_trt_xl( + work_dir: str, + version: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph=True, +): + from cuda import cudart + + # Validate image dimensions + image_height = height + image_width = width + if image_height % 8 != 0 or image_width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}." + ) + + assert batch_size <= max_batch_size + + from engine_builder import EngineType, get_engine_paths + + def init_pipeline(pipeline_class, pipeline_info): + engine_type = EngineType.ORT_TRT + + onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths( + work_dir, pipeline_info, engine_type + ) + + # Initialize pipeline + pipeline = pipeline_class( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=use_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + 17, + opt_image_height=height, + opt_image_width=width, + opt_batch_size=batch_size, + force_engine_rebuild=False, + static_batch=True, + static_image_shape=True, + max_workspace_size=0, + device_id=torch.cuda.current_device(), # TODO: might not work with CUDA_VISIBLE_DEVICES + ) + return pipeline + + from diffusion_models import PipelineInfo + from pipeline_img2img_xl import Img2ImgXLPipeline + from pipeline_txt2img_xl import Txt2ImgXLPipeline + + base_pipeline_info = PipelineInfo(version) + demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info) + + refiner_pipeline_info = PipelineInfo(version, is_refiner=True) + demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info) + + demo_base.load_resources(image_height, image_width, batch_size) + demo_refiner.load_resources(image_height, image_width, batch_size) + + def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False): + images, time_base = demo_base.run( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + return_type="latent", + ) + images, time_refiner = demo_refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + ) + return images, time_base + time_refiner + + def warmup(): + run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size, warmup=True) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + model_name = refiner_pipeline_info.name() + image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts, negative_prompt = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + if nvtx_profile: + cudart.cudaProfilerStart() + images, pipeline_time = run_sd_xl_inference( + [prompt] * batch_size, [negative_prompt] * batch_size, seed=123, warmup=True + ) + if nvtx_profile: + cudart.cudaProfilerStop() + images = demo_refiner.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + for k, image in enumerate(images): + filename = f"{image_filename_prefix}_{i}_{j}_{k}.png" + image.save(filename) + print("Image saved to", filename) + + demo_base.teardown() + demo_refiner.teardown() + + from tensorrt import __version__ as trt_version + + from onnxruntime import __version__ as ort_version + + return { + "model_name": model_name, + "engine": "onnxruntime", + "version": ort_version, + "provider": f"tensorrt{trt_version})", + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "enable_cuda_graph": use_cuda_graph, } @@ -808,6 +1250,15 @@ def parse_arguments(): help="Directory of saved onnx pipeline. It could be the output directory of optimize_pipeline.py.", ) + parser.add_argument( + "-w", + "--work_dir", + required=False, + type=str, + default=".", + help="Root directory to save exported onnx models, built engines etc.", + ) + parser.add_argument( "--enable_safety_checker", required=False, @@ -922,28 +1373,31 @@ def main(): args = parse_arguments() print(args) - if args.enable_cuda_graph: - if not (args.engine == "onnxruntime" and args.provider in ["cuda", "tensorrt"] and args.pipeline is None): - raise ValueError("The stable diffusion pipeline does not support CUDA graph.") + if args.engine == "onnxruntime": + if args.version in ["2.1"]: + # Set a flag to avoid overflow in attention, which causes black image output in SD 2.1 model. + # The environment variables shall be set before the first run of Attention or MultiHeadAttention operator. + os.environ["ORT_DISABLE_TRT_FLASH_ATTENTION"] = "1" from packaging import version from onnxruntime import __version__ as ort_version - if version.parse(ort_version) < version.parse("1.16"): - raise ValueError( - "CUDA graph requires ONNX Runtime 1.16. You can install nightly like the following:\n" - " pip uninstall onnxruntime-gpu\n" - " pip install ort-nightly-gpu -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/" - ) + if version.parse(ort_version) == version.parse("1.16.0"): + # ORT 1.16 has a bug that might trigger Attention RuntimeError when latest fusion script is applied on clip model. + # The walkaround is to enable fused causal attention, or disable Attention fusion for clip model. + os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1" + + if args.enable_cuda_graph: + if not (args.engine == "onnxruntime" and args.provider in ["cuda", "tensorrt"] and args.pipeline is None): + raise ValueError("The stable diffusion pipeline does not support CUDA graph.") + + if version.parse(ort_version) < version.parse("1.16"): + raise ValueError("CUDA graph requires ONNX Runtime 1.16 or later") coloredlogs.install(fmt="%(funcName)20s: %(message)s") - memory_monitor_type = None - if args.provider in ["cuda", "tensorrt"]: - memory_monitor_type = CudaMemoryMonitor - elif args.provider == "rocm": - memory_monitor_type = RocmMemoryMonitor + memory_monitor_type = "rocm" if args.provider == "rocm" else "cuda" start_memory = measure_gpu_memory(memory_monitor_type, None) print("GPU memory used before loading models:", start_memory) @@ -951,89 +1405,157 @@ def main(): sd_model = SD_MODELS[args.version] provider = PROVIDERS[args.provider] if args.engine == "onnxruntime" and args.provider == "tensorrt": - result = run_ort_trt( - sd_model, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.max_trt_batch_size, - args.enable_cuda_graph, - ) + if "xl" in args.version: + print("Testing Txt2ImgXLPipeline with static input shape. Backend is ORT TensorRT EP.") + result = run_ort_trt_xl( + work_dir=args.work_dir, + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=True, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, + ) + elif args.tuning: + print( + "Testing OnnxruntimeTensorRTStableDiffusionPipeline with {}.".format( + "static input shape" if args.enable_cuda_graph else "dynamic batch size" + ) + ) + result = run_ort_trt( + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + enable_cuda_graph=args.enable_cuda_graph, + ) + else: + print("Testing Txt2ImgPipeline with static input shape. Backend is ORT TensorRT EP.") + result = run_ort_trt_static( + work_dir=args.work_dir, + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, + ) + elif args.engine == "onnxruntime" and provider == "CUDAExecutionProvider" and args.pipeline is None: - print("Pipeline is not specified. Trying export and optimize onnx models...") + print( + "Testing OnnxruntimeCudaStableDiffusionPipeline with {} input shape. Backend is ORT CUDA EP.".format( + "static" if args.enable_cuda_graph else "dynamic" + ) + ) result = export_and_run_ort( - sd_model, - provider, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.enable_cuda_graph, + version=args.version, + provider=provider, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + enable_cuda_graph=args.enable_cuda_graph, ) elif args.engine == "onnxruntime": assert args.pipeline and os.path.isdir( args.pipeline ), "--pipeline should be specified for the directory of ONNX models" - - if args.version in ["2.1"]: - # Set a flag to avoid overflow in attention, which causes black image output in SD 2.1 model - # This shall be done before the first inference run. - os.environ["ORT_DISABLE_TRT_FLASH_ATTENTION"] = "1" - + print(f"Testing diffusers StableDiffusionPipeline with {provider} provider and tuning={args.tuning}") result = run_ort( - sd_model, - args.pipeline, - provider, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.tuning, + model_name=sd_model, + directory=args.pipeline, + provider=provider, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + tuning=args.tuning, + ) + elif args.engine == "tensorrt" and "xl" in args.version: + print("Testing Txt2ImgXLPipeline with static input shape. Backend is TensorRT.") + result = run_tensorrt_static_xl( + work_dir=args.work_dir, + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=True, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, ) elif args.engine == "tensorrt": - result = run_tensorrt( - sd_model, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.max_trt_batch_size, + print("Testing Txt2ImgPipeline with static input shape. Backend is TensorRT.") + result = run_tensorrt_static( + work_dir=args.work_dir, + version=args.version, + model_name=sd_model, + batch_size=args.batch_size, + disable_safety_checker=True, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, ) else: + print( + f"Testing Txt2ImgPipeline with dynamic input shape. Backend is PyTorch: compile={args.enable_torch_compile}, xformers={args.use_xformers}." + ) result = run_torch( - sd_model, - args.batch_size, - not args.enable_safety_checker, - args.enable_torch_compile, - args.use_xformers, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, + model_name=sd_model, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + enable_torch_compile=args.enable_torch_compile, + use_xformers=args.use_xformers, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, ) print(result) @@ -1068,8 +1590,9 @@ def main(): if __name__ == "__main__": + import traceback + try: main() - except Exception as e: - tb = sys.exc_info() - print(e.with_traceback(tb[2])) + except Exception: + traceback.print_exception(*sys.exc_info()) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py new file mode 100644 index 0000000000..fb051ac1ed --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -0,0 +1,97 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import coloredlogs +from cuda import cudart +from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_type +from pipeline_txt2img import Txt2ImgPipeline + +if __name__ == "__main__": + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + + args = parse_arguments(is_xl=False, description="Options for Stable Diffusion Demo") + prompt, negative_prompt = repeat_prompt(args) + + image_height = args.height + image_width = args.width + + # Register TensorRT plugins + engine_type = get_engine_type(args.engine) + if engine_type == EngineType.TRT: + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + max_batch_size = 16 + if engine_type != EngineType.ORT_CUDA and (args.build_dynamic_shape or image_height > 512 or image_width > 512): + max_batch_size = 4 + + batch_size = len(prompt) + if batch_size > max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + pipeline_info = PipelineInfo(args.version) + pipeline = init_pipeline(Txt2ImgPipeline, pipeline_info, engine_type, args, max_batch_size, batch_size) + + if engine_type == EngineType.TRT: + max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + pipeline.backend.activate_engines(shared_device_memory) + + if engine_type == EngineType.ORT_CUDA and args.enable_vae_slicing: + pipeline.backend.enable_vae_slicing() + + pipeline.load_resources(image_height, image_width, batch_size) + + def run_inference(warmup=False): + return pipeline.run( + prompt, + negative_prompt, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + return_type="image", + ) + + if not args.disable_cuda_graph: + # inference once to get cuda graph + _image, _latency = run_inference(warmup=True) + + print("[I] Warming up ..") + for _ in range(args.num_warmup_runs): + _image, _latency = run_inference(warmup=True) + + print("[I] Running StableDiffusion pipeline") + if args.nvtx_profile: + cudart.cudaProfilerStart() + _image, _latency = run_inference(warmup=False) + if args.nvtx_profile: + cudart.cudaProfilerStop() + + pipeline.teardown() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py new file mode 100644 index 0000000000..16e776a082 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -0,0 +1,136 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import coloredlogs +from cuda import cudart +from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_type +from pipeline_img2img_xl import Img2ImgXLPipeline +from pipeline_txt2img_xl import Txt2ImgXLPipeline + + +def run_demo(): + """Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image.""" + + args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo") + + prompt, negative_prompt = repeat_prompt(args) + + # Recommend image size as one of those used in training (see Appendix I in https://arxiv.org/pdf/2307.01952.pdf). + image_height = args.height + image_width = args.width + + # Register TensorRT plugins + engine_type = get_engine_type(args.engine) + if engine_type == EngineType.TRT: + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + max_batch_size = 16 + if (engine_type in [EngineType.ORT_TRT, EngineType.TRT]) and ( + args.build_dynamic_shape or image_height > 512 or image_width > 512 + ): + max_batch_size = 4 + + batch_size = len(prompt) + if batch_size > max_batch_size: + raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.") + + # No VAE decoder in base when it outputs latent instead of image. + base_info = PipelineInfo(args.version, use_vae=False) + base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size) + + refiner_info = PipelineInfo(args.version, is_refiner=True) + refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size) + + if engine_type == EngineType.TRT: + max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + base.backend.activate_engines(shared_device_memory) + refiner.backend.activate_engines(shared_device_memory) + + if engine_type == EngineType.ORT_CUDA: + enable_vae_slicing = args.enable_vae_slicing + if batch_size > 4 and not enable_vae_slicing: + print("Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4.") + enable_vae_slicing = True + if enable_vae_slicing: + refiner.backend.enable_vae_slicing() + + base.load_resources(image_height, image_width, batch_size) + refiner.load_resources(image_height, image_width, batch_size) + + def run_base_and_refiner(warmup=False): + images, time_base = base.run( + prompt, + negative_prompt, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + return_type="latent", + ) + + images, time_refiner = refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + ) + + return images, time_base + time_refiner + + if not args.disable_cuda_graph: + # inference once to get cuda graph + _, _ = run_base_and_refiner(warmup=True) + + print("[I] Warming up ..") + for _ in range(args.num_warmup_runs): + _, _ = run_base_and_refiner(warmup=True) + + print("[I] Running StableDiffusion XL pipeline") + if args.nvtx_profile: + cudart.cudaProfilerStart() + _, latency = run_base_and_refiner(warmup=False) + if args.nvtx_profile: + cudart.cudaProfilerStop() + + base.teardown() + + print("|------------|--------------|") + print("| {:^10} | {:>9.2f} ms |".format("e2e", latency)) + print("|------------|--------------|") + refiner.teardown() + + +if __name__ == "__main__": + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + run_demo() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py new file mode 100644 index 0000000000..e65efd2c53 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -0,0 +1,281 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import argparse + +import torch +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_paths + + +class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter): + pass + + +def parse_arguments(is_xl: bool, description: str): + parser = argparse.ArgumentParser(description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter) + + engines = ["ORT_CUDA", "ORT_TRT", "TRT"] + + parser.add_argument( + "--engine", + type=str, + default=engines[0], + choices=engines, + help="Backend engine in {engines}. " + "ORT_CUDA is CUDA execution provider; ORT_TRT is Tensorrt execution provider; TRT is TensorRT", + ) + + supported_versions = PipelineInfo.supported_versions(is_xl) + parser.add_argument( + "--version", + type=str, + default=supported_versions[-1] if is_xl else "1.5", + choices=supported_versions, + help="Version of Stable Diffusion" + (" XL." if is_xl else "."), + ) + + parser.add_argument( + "--height", + type=int, + default=1024 if is_xl else 512, + help="Height of image to generate (must be multiple of 8).", + ) + parser.add_argument( + "--width", type=int, default=1024 if is_xl else 512, help="Height of image to generate (must be multiple of 8)." + ) + + parser.add_argument( + "--scheduler", + type=str, + default="DDIM", + choices=["DDIM", "UniPC"] if is_xl else ["DDIM", "EulerA", "UniPC"], + help="Scheduler for diffusion process", + ) + + parser.add_argument( + "--work-dir", + default=".", + help="Root Directory to store torch or ONNX models, built engines and output images etc.", + ) + + parser.add_argument("prompt", nargs="+", help="Text prompt(s) to guide image generation.") + + parser.add_argument( + "--negative-prompt", nargs="*", default=[""], help="Optional negative prompt(s) to guide the image generation." + ) + parser.add_argument( + "--repeat-prompt", + type=int, + default=1, + choices=[1, 2, 4, 8, 16], + help="Number of times to repeat the prompt (batch size multiplier).", + ) + + parser.add_argument( + "--denoising-steps", + type=int, + default=30 if is_xl else 50, + help="Number of denoising steps" + (" in base." if is_xl else "."), + ) + + parser.add_argument( + "--guidance", + type=float, + default=5.0 if is_xl else 7.5, + help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.", + ) + + # ONNX export + parser.add_argument( + "--onnx-opset", + type=int, + default=None, + choices=range(14, 18), + help="Select ONNX opset version to target for exported models.", + ) + parser.add_argument( + "--force-onnx-export", action="store_true", help="Force ONNX export of CLIP, UNET, and VAE models." + ) + parser.add_argument( + "--force-onnx-optimize", action="store_true", help="Force ONNX optimizations for CLIP, UNET, and VAE models." + ) + + # Framework model ckpt + parser.add_argument( + "--framework-model-dir", + default="pytorch_model", + help="Directory for HF saved models. Default is pytorch_model.", + ) + parser.add_argument("--hf-token", type=str, help="HuggingFace API access token for downloading model checkpoints.") + + # Engine build options. + parser.add_argument("--force-engine-build", action="store_true", help="Force rebuilding the TensorRT engine.") + parser.add_argument( + "--build-dynamic-batch", action="store_true", help="Build TensorRT engines to support dynamic batch size." + ) + parser.add_argument( + "--build-dynamic-shape", action="store_true", help="Build TensorRT engines to support dynamic image sizes." + ) + + # Inference related options + parser.add_argument( + "--num-warmup-runs", type=int, default=5, help="Number of warmup runs before benchmarking performance." + ) + parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling.") + parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") + parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + + group = parser.add_argument_group("Options for ORT_CUDA engine only") + group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.") + + # TensorRT only options + group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only") + group.add_argument("--onnx-refit-dir", help="ONNX models to load the weights from.") + group.add_argument( + "--build-enable-refit", action="store_true", help="Enable Refit option in TensorRT engines during build." + ) + group.add_argument( + "--build-preview-features", action="store_true", help="Build TensorRT engines with preview features." + ) + group.add_argument( + "--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources." + ) + + args = parser.parse_args() + + if ( + args.engine in ["ORT_CUDA", "ORT_TRT"] + and (args.force_onnx_export or args.force_onnx_optimize) + and not args.force_engine_build + ): + raise ValueError( + "For ORT_CUDA or ORT_TRT, --force_onnx_export and --force_onnx_optimize are not supported. " + "Please use --force_engine_build instead." + ) + + # Validate image dimensions + if args.height % 8 != 0 or args.width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {args.height} and {args.width}." + ) + + if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph: + print("[I] CUDA Graph is disabled since dynamic input shape is configured.") + args.disable_cuda_graph = True + + if args.onnx_opset is None: + args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17 + + print(args) + + return args + + +def repeat_prompt(args): + if not isinstance(args.prompt, list): + raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}") + prompt = args.prompt * args.repeat_prompt + + if not isinstance(args.negative_prompt, list): + raise ValueError( + f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}" + ) + + if len(args.negative_prompt) == 1: + negative_prompt = args.negative_prompt * len(prompt) + else: + negative_prompt = args.negative_prompt + + return prompt, negative_prompt + + +def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size): + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + work_dir=args.work_dir, pipeline_info=pipeline_info, engine_type=engine_type + ) + + # Initialize demo + pipeline = pipeline_class( + pipeline_info, + scheduler=args.scheduler, + output_dir=output_dir, + hf_token=args.hf_token, + verbose=False, + nvtx_profile=args.nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=not args.disable_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + if engine_type == EngineType.ORT_CUDA: + # Build CUDA EP engines and load pytorch modules + pipeline.backend.build_engines( + engine_dir=engine_dir, + framework_model_dir=framework_model_dir, + onnx_dir=onnx_dir, + opt_image_height=args.height, + opt_image_width=args.height, + opt_batch_size=batch_size, + force_engine_rebuild=args.force_engine_build, + device_id=torch.cuda.current_device(), + ) + elif engine_type == EngineType.ORT_TRT: + # Build TensorRT EP engines and load pytorch modules + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + args.onnx_opset, + opt_image_height=args.height, + opt_image_width=args.height, + opt_batch_size=batch_size, + force_engine_rebuild=args.force_engine_build, + static_batch=not args.build_dynamic_batch, + static_image_shape=not args.build_dynamic_shape, + max_workspace_size=0, + device_id=torch.cuda.current_device(), + ) + elif engine_type == EngineType.TRT: + # Load TensorRT engines and pytorch modules + pipeline.backend.load_engines( + engine_dir, + framework_model_dir, + onnx_dir, + args.onnx_opset, + opt_batch_size=batch_size, + opt_image_height=args.height, + opt_image_width=args.height, + force_export=args.force_onnx_export, + force_optimize=args.force_onnx_optimize, + force_build=args.force_engine_build, + static_batch=not args.build_dynamic_batch, + static_shape=not args.build_dynamic_shape, + enable_refit=args.build_enable_refit, + enable_preview=args.build_preview_features, + enable_all_tactics=args.build_all_tactics, + timing_cache=timing_cache, + onnx_refit_dir=args.onnx_refit_dir, + ) + + return pipeline diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py new file mode 100644 index 0000000000..4a2e9eb344 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -0,0 +1,960 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from stable_diffusion_tensorrt_txt2img.py in diffusers and TensorRT demo diffusion, +# which has the following license: +# +# Copyright 2023 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import tempfile +from typing import Dict, List, Optional + +import onnx +import onnx_graphsurgeon as gs +import torch +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from onnx import GraphProto, ModelProto, shape_inference +from ort_optimizer import OrtStableDiffusionOptimizer +from polygraphy.backend.onnx.loader import fold_constants +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from onnxruntime.transformers.onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class TrtOptimizer: + def __init__(self, onnx_graph): + self.graph = gs.import_onnx(onnx_graph) + + def cleanup(self): + self.graph.cleanup().toposort() + + def get_optimized_onnx_graph(self): + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + + def infer_shapes(self): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF: + with tempfile.TemporaryDirectory() as temp_dir: + input_onnx_path = os.path.join(temp_dir, "model.onnx") + onnx.save_model( + onnx_graph, + input_onnx_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + output_onnx_path = os.path.join(temp_dir, "model_with_shape.onnx") + onnx.shape_inference.infer_shapes_path(input_onnx_path, output_onnx_path) + onnx_graph = onnx.load(output_onnx_path) + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + + +class PipelineInfo: + def __init__(self, version: str, is_inpaint: bool = False, is_refiner: bool = False, use_vae=False): + self.version = version + self._is_inpaint = is_inpaint + self._is_refiner = is_refiner + self._use_vae = use_vae + + if is_refiner: + assert self.is_xl() + + def is_inpaint(self) -> bool: + return self._is_inpaint + + def is_xl(self) -> bool: + return "xl" in self.version + + def is_xl_base(self) -> bool: + return self.is_xl() and not self._is_refiner + + def is_xl_refiner(self) -> bool: + return self.is_xl() and self._is_refiner + + def use_safetensors(self) -> bool: + return self.is_xl() + + def stages(self) -> List[str]: + if self.is_xl_base(): + return ["clip", "clip2", "unetxl"] + (["vae"] if self._use_vae else []) + + if self.is_xl_refiner(): + return ["clip2", "unetxl", "vae"] + + return ["clip", "unet", "vae"] + + def vae_scaling_factor(self) -> float: + return 0.13025 if self.is_xl() else 0.18215 + + @staticmethod + def supported_versions(is_xl: bool): + return ["xl-1.0"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"] + + def name(self) -> str: + if self.version == "1.4": + if self.is_inpaint(): + return "runwayml/stable-diffusion-inpainting" + else: + return "CompVis/stable-diffusion-v1-4" + elif self.version == "1.5": + if self.is_inpaint(): + return "runwayml/stable-diffusion-inpainting" + else: + return "runwayml/stable-diffusion-v1-5" + elif self.version == "2.0-base": + if self.is_inpaint(): + return "stabilityai/stable-diffusion-2-inpainting" + else: + return "stabilityai/stable-diffusion-2-base" + elif self.version == "2.0": + if self.is_inpaint(): + return "stabilityai/stable-diffusion-2-inpainting" + else: + return "stabilityai/stable-diffusion-2" + elif self.version == "2.1": + return "stabilityai/stable-diffusion-2-1" + elif self.version == "2.1-base": + return "stabilityai/stable-diffusion-2-1-base" + elif self.version == "xl-1.0": + if self.is_xl_refiner(): + return "stabilityai/stable-diffusion-xl-refiner-1.0" + else: + return "stabilityai/stable-diffusion-xl-base-1.0" + + raise ValueError(f"Incorrect version {self.version}") + + def short_name(self) -> str: + return self.name().split("/")[-1].replace("stable-diffusion", "sd") + + def clip_embedding_dim(self): + # TODO: can we read from config instead + if self.version in ("1.4", "1.5"): + return 768 + elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"): + return 1024 + elif self.version in ("xl-1.0") and self.is_xl_base(): + return 768 + else: + raise ValueError(f"Invalid version {self.version}") + + def clipwithproj_embedding_dim(self): + if self.version in ("xl-1.0"): + return 1280 + else: + raise ValueError(f"Invalid version {self.version}") + + def unet_embedding_dim(self): + if self.version in ("1.4", "1.5"): + return 768 + elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"): + return 1024 + elif self.version in ("xl-1.0") and self.is_xl_base(): + return 2048 + elif self.version in ("xl-1.0") and self.is_xl_refiner(): + return 1280 + else: + raise ValueError(f"Invalid version {self.version}") + + +class BaseModel: + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + fp16: bool = False, + max_batch_size: int = 16, + embedding_dim: int = 768, + text_maxlen: int = 77, + ): + self.name = self.__class__.__name__ + + self.pipeline_info = pipeline_info + + self.model = model + self.fp16 = fp16 + self.device = device + + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + self.text_maxlen = text_maxlen + + def get_ort_optimizer(self): + model_name_to_model_type = { + "CLIP": "clip", + "UNet": "unet", + "VAE": "vae", + "UNetXL": "unet", + "CLIPWithProj": "clip", + } + model_type = model_name_to_model_type[self.name] + return OrtStableDiffusionOptimizer(model_type) + + def get_model(self): + return self.model + + def from_pretrained(self, model_class, framework_model_dir, hf_token, subfolder, **kwargs): + model_dir = os.path.join(framework_model_dir, self.pipeline_info.name(), subfolder) + + if not os.path.exists(model_dir): + model = model_class.from_pretrained( + self.pipeline_info.name(), + subfolder=subfolder, + use_safetensors=self.pipeline_info.use_safetensors(), + use_auth_token=hf_token, + **kwargs, + ).to(self.device) + model.save_pretrained(model_dir) + else: + print(f"Load {self.name} pytorch model from: {model_dir}") + + model = model_class.from_pretrained(model_dir).to(self.device) + return model + + def load_model(self, framework_model_dir: str, hf_token: str, subfolder: str): + pass + + def get_input_names(self) -> List[str]: + pass + + def get_output_names(self) -> List[str]: + pass + + def get_dynamic_axes(self) -> Dict[str, Dict[int, str]]: + pass + + def get_sample_input(self, batch_size, image_height, image_width) -> tuple: + pass + + def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape): + """For TensorRT EP""" + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + + profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" + + if self.name != "CLIP": + if static_image_shape: + profile_id += f"_h_{image_height}_w_{image_width}" + else: + profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}" + + return profile_id + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + """For TensorRT""" + pass + + def get_shape_dict(self, batch_size, image_height, image_width): + pass + + def fp32_input_output_names(self) -> List[str]: + """For CUDA EP, we export ONNX model with FP32 first, then convert it to mixed precision model. + This is a list of input or output names that are kept as float32 during converting. + For the first version, we will use same data type as TensorRT. + """ + return [] + + def optimize_ort( + self, + input_onnx_path, + optimized_onnx_path, + to_fp16=True, + fp32_op_list=None, + optimize_by_ort=True, + optimize_by_fusion=True, + ): + optimizer = self.get_ort_optimizer() + optimizer.optimize( + input_onnx_path, + optimized_onnx_path, + float16=to_fp16, + keep_io_types=self.fp32_input_output_names(), + fp32_op_list=fp32_op_list, + optimize_by_ort=optimize_by_ort, + optimize_by_fusion=optimize_by_fusion, + ) + + def optimize_trt(self, input_onnx_path, optimized_onnx_path): + onnx_graph = onnx.load(input_onnx_path) + opt = TrtOptimizer(onnx_graph) + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.cleanup() + onnx_opt_graph = opt.get_optimized_onnx_graph() + + if onnx_opt_graph.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF: + onnx.save_model( + onnx_opt_graph, + optimized_onnx_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + else: + onnx.save(onnx_opt_graph, optimized_onnx_path) + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_image_shape else self.min_image_shape + max_image_height = image_height if static_image_shape else self.max_image_shape + min_image_width = image_width if static_image_shape else self.min_image_shape + max_image_width = image_width if static_image_shape else self.max_image_shape + min_latent_height = latent_height if static_image_shape else self.min_latent_shape + max_latent_height = latent_height if static_image_shape else self.max_latent_shape + min_latent_width = latent_width if static_image_shape else self.min_latent_shape + max_latent_width = latent_width if static_image_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) + + +class CLIP(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + max_batch_size, + embedding_dim: int = 0, + clip_skip=0, + ): + super().__init__( + pipeline_info, + model=model, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim if embedding_dim > 0 else pipeline_info.clip_embedding_dim(), + ) + self.output_hidden_state = pipeline_info.is_xl() + + # see https://github.com/huggingface/diffusers/pull/5057 for more information of clip_skip. + # Clip_skip=1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + self.clip_skip = clip_skip + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + # The exported onnx model has no hidden_state. For SD-XL, We will add hidden_state to optimized onnx model. + return ["text_embeddings"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_image_shape + ) + return { + "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + + if self.output_hidden_state: + output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) + + return output + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return (torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device),) + + def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path, use_external_data_format=False): + graph: GraphProto = model.graph + hidden_layers = -1 + for i in range(len(graph.node)): + for j in range(len(graph.node[i].output)): + name = graph.node[i].output[j] + if "layers" in name: + hidden_layers = max(int(name.split(".")[1].split("/")[0]), hidden_layers) + + assert self.clip_skip >= 0 and self.clip_skip < hidden_layers + + node_output_name = "/text_model/encoder/layers.{}/Add_1_output_0".format(hidden_layers - 1 - self.clip_skip) + + # search the name in outputs of all node + found = False + for i in range(len(graph.node)): + for j in range(len(graph.node[i].output)): + if graph.node[i].output[j] == node_output_name: + found = True + break + if found: + break + if not found: + raise RuntimeError("Failed to find hidden_states graph output in clip") + + # Insert a Cast (fp32 -> fp16) node so that hidden_states has same data type as the first graph output. + graph_output_name = "hidden_states" + cast_node = onnx.helper.make_node("Cast", inputs=[node_output_name], outputs=[graph_output_name]) + cast_node.attribute.extend([onnx.helper.make_attribute("to", graph.output[0].type.tensor_type.elem_type)]) + + hidden_state = graph.output.add() + hidden_state.CopyFrom( + onnx.helper.make_tensor_value_info( + graph_output_name, + graph.output[0].type.tensor_type.elem_type, + ["B", self.text_maxlen, self.embedding_dim], + ) + ) + + onnx_model = OnnxModel(model) + onnx_model.add_node(cast_node) + onnx_model.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) + + def optimize_ort( + self, + input_onnx_path, + optimized_onnx_path, + to_fp16=True, + fp32_op_list=None, + optimize_by_ort=True, + optimize_by_fusion=True, + ): + optimizer = self.get_ort_optimizer() + + if not self.output_hidden_state: + optimizer.optimize( + input_onnx_path, + optimized_onnx_path, + float16=to_fp16, + keep_io_types=[], + fp32_op_list=fp32_op_list, + keep_outputs=["text_embeddings"], + optimize_by_ort=optimize_by_ort, + optimize_by_fusion=optimize_by_fusion, + ) + elif optimize_by_fusion: + with tempfile.TemporaryDirectory() as tmp_dir: + # Save to a temporary file so that we can load it with Onnx Runtime. + logger.info("Saving a temporary model to add hidden_states to graph output ...") + tmp_model_path = os.path.join(tmp_dir, "model.onnx") + + model = onnx.load(input_onnx_path) + self.add_hidden_states_graph_output(model, tmp_model_path, use_external_data_format=True) + optimizer.optimize( + tmp_model_path, + optimized_onnx_path, + float16=to_fp16, + keep_io_types=[], + fp32_op_list=fp32_op_list, + keep_outputs=["text_embeddings", "hidden_states"], + optimize_by_ort=optimize_by_ort, + optimize_by_fusion=optimize_by_fusion, + ) + else: # input is optimized model, there is no need to add hidden states. + optimizer.optimize( + input_onnx_path, + optimized_onnx_path, + float16=to_fp16, + keep_io_types=[], + fp32_op_list=fp32_op_list, + keep_outputs=["text_embeddings", "hidden_states"], + optimize_by_ort=optimize_by_ort, + optimize_by_fusion=optimize_by_fusion, + ) + + def optimize_trt(self, input_onnx_path, optimized_onnx_path): + onnx_graph = onnx.load(input_onnx_path) + opt = TrtOptimizer(onnx_graph) + opt.select_outputs([0]) # delete graph output#1 + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.select_outputs([0], names=["text_embeddings"]) # rename network output + opt.cleanup() + onnx_opt_graph = opt.get_optimized_onnx_graph() + if self.output_hidden_state: + self.add_hidden_states_graph_output(onnx_opt_graph, optimized_onnx_path) + else: + onnx.save(onnx_opt_graph, optimized_onnx_path) + + def load_model(self, framework_model_dir, hf_token, subfolder="text_encoder"): + return self.from_pretrained(CLIPTextModel, framework_model_dir, hf_token, subfolder) + + +class CLIPWithProj(CLIP): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + max_batch_size=16, + clip_skip=0, + ): + super().__init__( + pipeline_info, + model, + device=device, + max_batch_size=max_batch_size, + embedding_dim=pipeline_info.clipwithproj_embedding_dim(), + clip_skip=clip_skip, + ) + + def load_model(self, framework_model_dir, hf_token, subfolder="text_encoder_2"): + return self.from_pretrained(CLIPTextModelWithProjection, framework_model_dir, hf_token, subfolder) + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.embedding_dim), + } + + if self.output_hidden_state: + output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) + + return output + + +class UNet(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + fp16=False, # used by TRT + max_batch_size=16, + text_maxlen=77, + unet_dim=4, + ): + super().__init__( + pipeline_info, + model=model, + device=device, + fp16=fp16, + max_batch_size=max_batch_size, + embedding_dim=pipeline_info.unet_embedding_dim(), + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + + def load_model(self, framework_model_dir, hf_token, subfolder="unet"): + options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": [1], + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + ) + + def fp32_input_output_names(self) -> List[str]: + return ["sample", "timestep"] + + +class UNetXL(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + fp16=False, # used by TRT + max_batch_size=16, + text_maxlen=77, + unet_dim=4, + time_dim=6, + ): + super().__init__( + pipeline_info, + model, + device=device, + fp16=fp16, + max_batch_size=max_batch_size, + embedding_dim=pipeline_info.unet_embedding_dim(), + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + self.time_dim = time_dim + + def load_model(self, framework_model_dir, hf_token, subfolder="unet"): + options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + "text_embeds": {0: "2B"}, + "time_ids": {0: "2B"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + "text_embeds": [(2 * min_batch, 1280), (2 * batch_size, 1280), (2 * max_batch, 1280)], + "time_ids": [ + (2 * min_batch, self.time_dim), + (2 * batch_size, self.time_dim), + (2 * max_batch, self.time_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": (1,), + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + "text_embeds": (2 * batch_size, 1280), + "time_ids": (2 * batch_size, self.time_dim), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + { + "added_cond_kwargs": { + "text_embeds": torch.randn(2 * batch_size, 1280, dtype=dtype, device=self.device), + "time_ids": torch.randn(2 * batch_size, self.time_dim, dtype=dtype, device=self.device), + } + }, + ) + + def fp32_input_output_names(self) -> List[str]: + return ["sample", "timestep"] + + +# VAE Decoder +class VAE(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + max_batch_size, + fp16: bool = False, + custom_fp16_vae: Optional[str] = None, + ): + super().__init__( + pipeline_info, + model=model, + device=device, + fp16=fp16, + max_batch_size=max_batch_size, + ) + + # For SD XL, need custom trained fp16 model to speed up, and avoid overflow at the same time. + self.custom_fp16_vae = custom_fp16_vae + + def load_model(self, framework_model_dir, hf_token: Optional[str] = None, subfolder: str = "vae_decoder"): + model_name = self.custom_fp16_vae or self.pipeline_info.name() + + model_dir = os.path.join(framework_model_dir, model_name, subfolder) + if not os.path.exists(model_dir): + if self.custom_fp16_vae: + vae = AutoencoderKL.from_pretrained(self.custom_fp16_vae, torch_dtype=torch.float16).to(self.device) + else: + vae = AutoencoderKL.from_pretrained( + self.pipeline_info.name(), + subfolder="vae", + use_safetensors=self.pipeline_info.use_safetensors(), + use_auth_token=hf_token, + ).to(self.device) + vae.save_pretrained(model_dir) + else: + print(f"Load {self.name} pytorch model from: {model_dir}") + if self.custom_fp16_vae: + vae = AutoencoderKL.from_pretrained(model_dir, torch_dtype=torch.float16).to(self.device) + else: + vae = AutoencoderKL.from_pretrained(model_dir).to(self.device) + + vae.forward = vae.decode + return vae + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "latent": [ + (min_batch, 4, min_latent_height, min_latent_width), + (batch_size, 4, latent_height, latent_width), + (max_batch, 4, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device),) + + def fp32_input_output_names(self) -> List[str]: + return [] if self.fp16 else ["latent", "images"] + + +def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, hf_token, subfolder="tokenizer"): + tokenizer_dir = os.path.join(framework_model_dir, pipeline_info.name(), subfolder) + + if not os.path.exists(tokenizer_dir): + model = CLIPTokenizer.from_pretrained( + pipeline_info.name(), + subfolder=subfolder, + use_safetensors=pipeline_info.is_xl(), + use_auth_token=hf_token, + ) + model.save_pretrained(tokenizer_dir) + else: + print(f"[I] Load tokenizer pytorch model from: {tokenizer_dir}") + model = CLIPTokenizer.from_pretrained(tokenizer_dir) + return model + + +class TorchVAEEncoder(torch.nn.Module): + def __init__(self, vae_encoder): + super().__init__() + self.vae_encoder = vae_encoder + + def forward(self, x): + return self.vae_encoder.encode(x).latent_dist.sample() + + +class VAEEncoder(BaseModel): + def __init__(self, pipeline_info: PipelineInfo, model, device, max_batch_size): + super().__init__( + pipeline_info, + model=model, + device=device, + max_batch_size=max_batch_size, + ) + + def load_model(self, framework_model_dir, hf_token, subfolder="vae_encoder"): + vae = self.from_pretrained(AutoencoderKL, framework_model_dir, hf_token, subfolder) + return TorchVAEEncoder(vae) + + def get_input_names(self): + return ["images"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + self.check_dims(batch_size, image_height, image_width) + + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + + return { + "images": [ + (min_batch, 3, min_image_height, min_image_width), + (batch_size, 3, image_height, image_width), + (max_batch, 3, max_image_height, max_image_width), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "images": (batch_size, 3, image_height, image_width), + "latent": (batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py new file mode 100644 index 0000000000..ec3041e134 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py @@ -0,0 +1,722 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from utilities.py of TensorRT demo diffusion, which has the following license: +# +# Copyright 2022 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +from typing import List, Optional + +import numpy as np +import torch + + +class DDIMScheduler: + def __init__( + self, + device="cuda", + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + clip_sample: bool = False, + set_alpha_to_one: bool = False, + steps_offset: int = 1, + prediction_type: str = "epsilon", + ): + # this schedule is very specific to the latent diffusion model. + betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + alphas = 1.0 - betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.steps_offset = steps_offset + self.num_train_timesteps = num_train_timesteps + self.clip_sample = clip_sample + self.prediction_type = prediction_type + self.device = device + + def configure(self): + variance = np.zeros(self.num_inference_steps, dtype=np.float32) + for idx, timestep in enumerate(self.timesteps): + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + variance[idx] = self._get_variance(timestep, prev_timestep) + self.variance = torch.from_numpy(variance).to(self.device) + + timesteps = self.timesteps.long().cpu() + self.alphas_cumprod = self.alphas_cumprod[timesteps].to(self.device) + self.final_alpha_cumprod = self.final_alpha_cumprod.to(self.device) + + def scale_model_input(self, sample: torch.FloatTensor, idx, *args, **kwargs) -> torch.FloatTensor: + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def set_timesteps(self, num_inference_steps: int): + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(self.device) + self.timesteps += self.steps_offset + + def step( + self, + model_output, + sample, + idx, + timestep, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: torch.FloatTensor = None, + ): + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + prev_idx = idx + 1 + alpha_prod_t = self.alphas_cumprod[idx] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_idx] if prev_idx < self.num_inference_steps else self.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.prediction_type == "sample": + pred_original_sample = model_output + elif self.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # o_t = sqrt((1 - a_t-1)/(1 - a_t)) * sqrt(1 - a_t/a_t-1) + variance = self.variance[idx] + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the model_output is always re-derived from the clipped x_0 in Glide + model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 + device = model_output.device + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = torch.randn( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) + variance = variance ** (0.5) * eta * variance_noise + + prev_sample = prev_sample + variance + + return prev_sample + + def add_noise(self, init_latents, noise, idx, latent_timestep): + sqrt_alpha_prod = self.alphas_cumprod[idx] ** 0.5 + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[idx]) ** 0.5 + noisy_latents = sqrt_alpha_prod * init_latents + sqrt_one_minus_alpha_prod * noise + + return noisy_latents + + +class EulerAncestralDiscreteScheduler: + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + device="cuda", + steps_offset=0, + prediction_type="epsilon", + ): + # this schedule is very specific to the latent diffusion model. + betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + alphas = 1.0 - betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.is_scale_input_called = False + self.device = device + self.num_train_timesteps = num_train_timesteps + self.steps_offset = steps_offset + self.prediction_type = prediction_type + + def scale_model_input(self, sample: torch.FloatTensor, idx, timestep, *args, **kwargs) -> torch.FloatTensor: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int): + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy() + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=self.device) + self.timesteps = torch.from_numpy(timesteps).to(device=self.device) + + def configure(self): + dts = np.zeros(self.num_inference_steps, dtype=np.float32) + sigmas_up = np.zeros(self.num_inference_steps, dtype=np.float32) + for idx, timestep in enumerate(self.timesteps): + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + sigma_from = self.sigmas[step_index] + sigma_to = self.sigmas[step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + dt = sigma_down - sigma + dts[idx] = dt + sigmas_up[idx] = sigma_up + + self.dts = torch.from_numpy(dts).to(self.device) + self.sigmas_up = torch.from_numpy(sigmas_up).to(self.device) + + def step( + self, + model_output, + sample, + idx, + timestep, + generator=None, + ): + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + sigma_up = self.sigmas_up[idx] + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = self.dts[idx] + + prev_sample = sample + derivative * dt + + device = model_output.device + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(device) + + prev_sample = prev_sample + noise * sigma_up + + return prev_sample + + def add_noise(self, original_samples, noise, idx, timestep=None): + step_index = (self.timesteps == timestep).nonzero().item() + noisy_samples = original_samples + noise * self.sigmas[step_index] + return noisy_samples + + +class UniPCMultistepScheduler: + def __init__( + self, + device="cuda", + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: Optional[List[int]] = None, + ): + self.device = device + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector if disable_corrector else [] + self.last_sample = None + self.num_train_timesteps = num_train_timesteps + self.solver_order = solver_order + self.prediction_type = prediction_type + self.thresholding = thresholding + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.sample_max_value = sample_max_value + self.solver_type = solver_type + self.lower_order_final = lower_order_final + + def set_timesteps(self, num_inference_steps: int): + timesteps = ( + np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(self.device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.solver_order + self.lower_order_nums = 0 + self.last_sample = None + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + if self.predict_x0: + if self.prediction_type == "epsilon": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.prediction_type == "sample": + x0_pred = model_output + elif self.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + if self.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.prediction_type == "epsilon": + return model_output + elif self.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + def multistep_uni_p_bh_update( + self, + model_output: torch.FloatTensor, + prev_timestep: int, + sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + timestep_list = self.timestep_list + model_output_list = self.model_outputs + + s0, t = self.timestep_list[-1], prev_timestep + m0 = model_output_list[-1] + x = sample + + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + h = lambda_t - lambda_s0 + + rks = [] + d1s = [] + for i in range(1, order): + si = timestep_list[-(i + 1)] + mi = model_output_list[-(i + 1)] + lambda_si = self.lambda_t[si] + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + d1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=self.device) + + r = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.solver_type == "bh1": + b_h = hh + elif self.solver_type == "bh2": + b_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + r.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / b_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + r = torch.stack(r) + b = torch.tensor(b, device=self.device) + + if len(d1s) > 0: + d1s = torch.stack(d1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=self.device) + else: + rhos_p = torch.linalg.solve(r[:-1, :-1], b[:-1]) + else: + d1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if d1s is not None: + pred_res = torch.einsum("k,bkchw->bchw", rhos_p, d1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * b_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if d1s is not None: + pred_res = torch.einsum("k,bkchw->bchw", rhos_p, d1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * b_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.FloatTensor, + this_timestep: int, + last_sample: torch.FloatTensor, + # this_sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + timestep_list = self.timestep_list + model_output_list = self.model_outputs + + s0, t = timestep_list[-1], this_timestep + m0 = model_output_list[-1] + x = last_sample + # x_t = this_sample + model_t = this_model_output + + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + h = lambda_t - lambda_s0 + + rks = [] + d1s = [] + for i in range(1, order): + si = timestep_list[-(i + 1)] + mi = model_output_list[-(i + 1)] + lambda_si = self.lambda_t[si] + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + d1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=self.device) + + r = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.solver_type == "bh1": + b_h = hh + elif self.solver_type == "bh2": + b_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + r.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / b_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + r = torch.stack(r) + b = torch.tensor(b, device=self.device) + + if len(d1s) > 0: + d1s = torch.stack(d1s, dim=1) + else: + d1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=self.device) + else: + rhos_c = torch.linalg.solve(r, b) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if d1s is not None: + corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], d1s) + else: + corr_res = 0 + d1_t = model_t - m0 + x_t = x_t_ - alpha_t * b_h * (corr_res + rhos_c[-1] * d1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if d1s is not None: + corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], d1s) + else: + corr_res = 0 + d1_t = model_t - m0 + x_t = x_t_ - sigma_t * b_h * (corr_res + rhos_c[-1] * d1_t) + x_t = x_t.to(x.dtype) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ): + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + + use_corrector = step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None + + model_output_convert = self.convert_model_output(model_output, timestep, sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + this_timestep=timestep, + last_sample=self.last_sample, + # this_sample=sample, + order=self.this_order, + ) + + # now prepare to run the predictor + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + + for i in range(self.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.lower_order_final: + this_order = min(self.solver_order, len(self.timesteps) - step_index) + else: + this_order = self.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + prev_timestep=prev_timestep, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return prev_sample + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + idx, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=self.device, dtype=original_samples.dtype) + timesteps = timesteps.to(self.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def configure(self): + pass + + def __len__(self): + return self.num_train_timesteps diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py new file mode 100644 index 0000000000..dfdfa007d7 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -0,0 +1,204 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os +from enum import Enum + +import torch +from diffusion_models import CLIP, VAE, CLIPWithProj, PipelineInfo, UNet, UNetXL + + +class EngineType(Enum): + ORT_CUDA = 0 # ONNX Runtime CUDA Execution Provider + ORT_TRT = 1 # ONNX Runtime TensorRT Execution Provider + TRT = 2 # TensorRT + TORCH = 3 # PyTorch + + +def get_engine_type(name: str) -> EngineType: + name_to_type = { + "ORT_CUDA": EngineType.ORT_CUDA, + "ORT_TRT": EngineType.ORT_TRT, + "TRT": EngineType.TRT, + "TORCH": EngineType.TORCH, + } + return name_to_type[name] + + +class EngineBuilder: + def __init__( + self, + engine_type: EngineType, + pipeline_info: PipelineInfo, + device="cuda", + max_batch_size=16, + hf_token=None, + use_cuda_graph=False, + ): + """ + Initializes the Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + device (str | torch.device): + device to run engine + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + self.engine_type = engine_type + self.pipeline_info = pipeline_info + self.max_batch_size = max_batch_size + self.hf_token = hf_token + self.use_cuda_graph = use_cuda_graph + self.device = torch.device(device) + self.torch_device = torch.device(device, torch.cuda.current_device()) + self.stages = pipeline_info.stages() + + # TODO: use custom fp16 for ORT_TRT, and no need to fallback to torch. + self.vae_torch_fallback = self.pipeline_info.is_xl() and engine_type != EngineType.ORT_CUDA + + # For SD XL, use an VAE that modified to run in fp16 precision without generating NaNs. + self.custom_fp16_vae = ( + "madebyollin/sdxl-vae-fp16-fix" + if self.pipeline_info.is_xl() and self.engine_type == EngineType.ORT_CUDA + else None + ) + + self.models = {} + self.engines = {} + self.torch_models = {} + self.use_vae_slicing = False + + def enable_vae_slicing(self): + self.use_vae_slicing = True + + def teardown(self): + for engine in self.engines.values(): + del engine + self.engines = {} + + def get_cached_model_name(self, model_name): + if self.pipeline_info.is_inpaint(): + model_name += "_inpaint" + return model_name + + def get_onnx_path(self, model_name, onnx_dir, opt=True, suffix=""): + engine_name = self.engine_type.name.lower() + directory_name = self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") + suffix + onnx_model_dir = os.path.join(onnx_dir, directory_name) + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, "model.onnx") + + def get_engine_path(self, engine_dir, model_name, profile_id): + return os.path.join(engine_dir, self.get_cached_model_name(model_name) + profile_id) + + def load_models(self, framework_model_dir: str): + # Disable torch SDPA since torch 2.0.* cannot export it to ONNX + if hasattr(torch.nn.functional, "scaled_dot_product_attention"): + delattr(torch.nn.functional, "scaled_dot_product_attention") + + # For TRT or ORT_TRT, we will export fp16 torch model for UNet. + # For ORT_CUDA, we export fp32 model first, then optimize to fp16. + export_fp16_unet = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT] + + if "clip" in self.stages: + self.models["clip"] = CLIP( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + max_batch_size=self.max_batch_size, + clip_skip=0, + ) + + if "clip2" in self.stages: + self.models["clip2"] = CLIPWithProj( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + max_batch_size=self.max_batch_size, + clip_skip=0, + ) + + if "unet" in self.stages: + self.models["unet"] = UNet( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + fp16=export_fp16_unet, + max_batch_size=self.max_batch_size, + unet_dim=(9 if self.pipeline_info.is_inpaint() else 4), + ) + + if "unetxl" in self.stages: + self.models["unetxl"] = UNetXL( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + fp16=export_fp16_unet, + max_batch_size=self.max_batch_size, + unet_dim=4, + time_dim=(5 if self.pipeline_info.is_xl_refiner() else 6), + ) + + # VAE Decoder + if "vae" in self.stages: + self.models["vae"] = VAE( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + max_batch_size=self.max_batch_size, + custom_fp16_vae=self.custom_fp16_vae, + ) + + if self.vae_torch_fallback: + self.torch_models["vae"] = self.models["vae"].load_model(framework_model_dir, self.hf_token) + + def load_resources(self, image_height, image_width, batch_size): + # Allocate buffers for I/O bindings + for model_name, obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + slice_size = 1 if (model_name == "vae" and self.use_vae_slicing) else batch_size + self.engines[model_name].allocate_buffers( + shape_dict=obj.get_shape_dict(slice_size, image_height, image_width), device=self.torch_device + ) + + def _vae_decode(self, latents): + if self.vae_torch_fallback: + if not self.custom_fp16_vae: + latents = latents.to(dtype=torch.float32) + self.torch_models["vae"] = self.torch_models["vae"].to(dtype=torch.float32) + images = self.torch_models["vae"](latents)["sample"] + else: + images = self.run_engine("vae", {"latent": latents})["images"] + + return images + + def vae_decode(self, latents): + if self.use_vae_slicing: + # The output tensor points to same buffer. Need clone it to avoid overwritten. + decoded_slices = [self._vae_decode(z_slice).clone() for z_slice in latents.split(1)] + return torch.cat(decoded_slices) + + return self._vae_decode(latents) + + +def get_engine_paths(work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType): + root_dir = work_dir or "." + short_name = pipeline_info.short_name() + + # When both ORT_CUDA and ORT_TRT/TRT is used, we shall make sub directory for each engine since + # ORT_CUDA need fp32 torch model, while ORT_TRT/TRT use fp16 torch model. + onnx_dir = os.path.join(root_dir, engine_type.name, short_name, "onnx") + engine_dir = os.path.join(root_dir, engine_type.name, short_name, "engine") + output_dir = os.path.join(root_dir, engine_type.name, short_name, "output") + timing_cache = os.path.join(root_dir, engine_type.name, "timing_cache") + framework_model_dir = os.path.join(root_dir, engine_type.name, "torch_model") + + return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py new file mode 100644 index 0000000000..07c675b2ed --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -0,0 +1,296 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import gc +import logging +import os +import shutil +from typing import List, Optional + +import torch +from diffusion_models import PipelineInfo +from engine_builder import EngineBuilder, EngineType +from ort_utils import CudaSession + +import onnxruntime as ort + +logger = logging.getLogger(__name__) + + +class OrtCudaEngine(CudaSession): + def __init__( + self, + onnx_path, + device_id: int = 0, + enable_cuda_graph: bool = False, + disable_optimization: bool = False, + ): + self.onnx_path = onnx_path + self.provider = "CUDAExecutionProvider" + self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) + # self.provider_options["enable_skip_layer_norm_strict_mode"] = True + + session_options = ort.SessionOptions() + + # When the model has been optimized by onnxruntime, we can disable optimization to save session creation time. + if disable_optimization: + session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + + logger.info("creating CUDA EP session for %s", onnx_path) + ort_session = ort.InferenceSession( + onnx_path, + session_options, + providers=[ + (self.provider, self.provider_options), + "CPUExecutionProvider", + ], + ) + logger.info("created CUDA EP session for %s", onnx_path) + + device = torch.device("cuda", device_id) + super().__init__(ort_session, device, enable_cuda_graph) + + def allocate_buffers(self, shape_dict, device): + super().allocate_buffers(shape_dict) + + +class _ModelConfig: + """ + Configuration of one model (like Clip, UNet etc) on ONNX export and optimization for CUDA provider. + For example, if you want to use fp32 in layer normalization, set the following: + force_fp32_ops=["SkipLayerNormalization", "LayerNormalization"] + """ + + def __init__( + self, + onnx_opset_version: int, + use_cuda_graph: bool, + fp16: bool = True, + force_fp32_ops: Optional[List[str]] = None, + optimize_by_ort: bool = True, + ): + self.onnx_opset_version = onnx_opset_version + self.use_cuda_graph = use_cuda_graph + self.fp16 = fp16 + self.force_fp32_ops = force_fp32_ops + self.optimize_by_ort = optimize_by_ort + + +class OrtCudaEngineBuilder(EngineBuilder): + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + hf_token=None, + device="cuda", + use_cuda_graph=False, + ): + """ + Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + device (str): + device to run. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + super().__init__( + EngineType.ORT_CUDA, + pipeline_info, + max_batch_size=max_batch_size, + hf_token=hf_token, + device=device, + use_cuda_graph=use_cuda_graph, + ) + + self.model_config = {} + + def _configure( + self, + model_name: str, + onnx_opset_version: int, + use_cuda_graph: bool, + fp16: bool = True, + force_fp32_ops: Optional[List[str]] = None, + optimize_by_ort: bool = True, + ): + self.model_config[model_name] = _ModelConfig( + onnx_opset_version, + use_cuda_graph, + fp16=fp16, + force_fp32_ops=force_fp32_ops, + optimize_by_ort=optimize_by_ort, + ) + + def configure_xl(self, onnx_opset_version: int): + self._configure( + "clip", + onnx_opset_version=onnx_opset_version, + use_cuda_graph=self.use_cuda_graph, + ) + self._configure( + "clip2", + onnx_opset_version=onnx_opset_version, # TODO: ArgMax-12 is not implemented in CUDA + use_cuda_graph=False, # TODO: fix Runtime Error with cuda graph + ) + self._configure( + "unetxl", + onnx_opset_version=onnx_opset_version, + use_cuda_graph=self.use_cuda_graph, + ) + + self._configure( + "vae", + onnx_opset_version=onnx_opset_version, + use_cuda_graph=self.use_cuda_graph, + ) + + def build_engines( + self, + engine_dir: str, + framework_model_dir: str, + onnx_dir: str, + onnx_opset_version: int = 17, + opt_image_height: int = 512, + opt_image_width: int = 512, + opt_batch_size: int = 1, + force_engine_rebuild: bool = False, + device_id: int = 0, + save_fp32_intermediate_model=False, + ): + self.torch_device = torch.device("cuda", device_id) + self.load_models(framework_model_dir) + + if force_engine_rebuild: + if os.path.isdir(onnx_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) + shutil.rmtree(onnx_dir) + if os.path.isdir(engine_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) + shutil.rmtree(engine_dir) + + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + + # Add default configuration if missing + if self.pipeline_info.is_xl(): + self.configure_xl(onnx_opset_version) + for model_name in self.models: + if model_name not in self.model_config: + self.model_config[model_name] = _ModelConfig(onnx_opset_version, self.use_cuda_graph) + + # Export models to ONNX + for model_name, model_obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32") + onnx_fp16_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp16") + onnx_opt_path = onnx_fp16_path if self.model_config[model_name].fp16 else onnx_fp32_path + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): + print("----") + logger.info("Exporting model: %s", onnx_path) + model = model_obj.load_model(framework_model_dir, self.hf_token) + if model_name == "vae": + model.to(torch.float32) + + with torch.inference_mode(): + # For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern. + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=self.model_config[model_name].onnx_opset_version, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.info("Found cached model: %s", onnx_path) + + # Generate fp32 optimized model. + # If final target is fp16 model, we save fp32 optimized model so that it is easy to tune + # fp16 conversion. That could save a lot of time in developing. + use_fp32_intermediate = save_fp32_intermediate_model and self.model_config[model_name].fp16 + if use_fp32_intermediate: + if not os.path.exists(onnx_fp32_path): + print("------") + logger.info("Generating optimized model: %s", onnx_fp32_path) + + # There is risk that some ORT fused ops fp32 only. So far, we have not encountered such issue. + model_obj.optimize_ort( + onnx_path, + onnx_fp32_path, + to_fp16=False, + fp32_op_list=self.model_config[model_name].force_fp32_ops, + optimize_by_ort=self.model_config[model_name].optimize_by_ort, + ) + else: + logger.info("Found cached optimized model: %s", onnx_fp32_path) + + # Generate the final optimized model. + if not os.path.exists(onnx_opt_path): + print("------") + logger.info("Generating optimized model: %s", onnx_opt_path) + + # When there is fp32 intermediate optimized model, this will just convert model from fp32 to fp16. + optimize_by_ort = False if use_fp32_intermediate else self.model_config[model_name].optimize_by_ort + + model_obj.optimize_ort( + onnx_fp32_path if use_fp32_intermediate else onnx_path, + onnx_opt_path, + to_fp16=self.model_config[model_name].fp16, + fp32_op_list=self.model_config[model_name].force_fp32_ops, + optimize_by_ort=optimize_by_ort, + optimize_by_fusion=not use_fp32_intermediate, + ) + else: + logger.info("Found cached optimized model: %s", onnx_opt_path) + + built_engines = {} + for model_name in self.models: + if model_name == "vae" and self.vae_torch_fallback: + continue + + onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32") + onnx_fp16_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp16") + onnx_opt_path = onnx_fp16_path if self.model_config[model_name].fp16 else onnx_fp32_path + + use_cuda_graph = self.model_config[model_name].use_cuda_graph + + engine = OrtCudaEngine( + onnx_opt_path, + device_id=device_id, + enable_cuda_graph=use_cuda_graph, + disable_optimization=False, + ) + + logger.info("%s options for %s: %s", engine.provider, model_name, engine.provider_options) + built_engines[model_name] = engine + + self.engines = built_engines + + return built_engines + + def run_engine(self, model_name, feed_dict): + return self.engines[model_name].infer(feed_dict) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py new file mode 100644 index 0000000000..8a39dc2ed6 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py @@ -0,0 +1,263 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import gc +import logging +import os +import shutil + +import torch +from cuda import cudart +from diffusion_models import PipelineInfo +from engine_builder import EngineBuilder, EngineType +from ort_utils import CudaSession + +import onnxruntime as ort + +logger = logging.getLogger(__name__) + + +class OrtTensorrtEngine(CudaSession): + def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, workspace_size, enable_cuda_graph): + self.engine_path = engine_path + self.ort_trt_provider_options = self.get_tensorrt_provider_options( + input_profile, + workspace_size, + fp16, + device_id, + enable_cuda_graph, + ) + + session_options = ort.SessionOptions() + session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + logger.info("creating TRT EP session for %s", onnx_path) + ort_session = ort.InferenceSession( + onnx_path, + session_options, + providers=[ + ("TensorrtExecutionProvider", self.ort_trt_provider_options), + ], + ) + logger.info("created TRT EP session for %s", onnx_path) + + device = torch.device("cuda", device_id) + super().__init__(ort_session, device, enable_cuda_graph) + + def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph): + trt_ep_options = { + "device_id": device_id, + "trt_fp16_enable": fp16, + "trt_engine_cache_enable": True, + "trt_timing_cache_enable": True, + "trt_detailed_build_log": True, + "trt_engine_cache_path": self.engine_path, + } + + if enable_cuda_graph: + trt_ep_options["trt_cuda_graph_enable"] = True + + if workspace_size > 0: + trt_ep_options["trt_max_workspace_size"] = workspace_size + + if input_profile: + min_shapes = [] + max_shapes = [] + opt_shapes = [] + for name, profile in input_profile.items(): + assert isinstance(profile, list) and len(profile) == 3 + min_shape = profile[0] + opt_shape = profile[1] + max_shape = profile[2] + assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape) + + min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape])) + opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape])) + max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape])) + + trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes) + trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes) + trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes) + + logger.info("trt_ep_options=%s", trt_ep_options) + + return trt_ep_options + + def allocate_buffers(self, shape_dict, device): + super().allocate_buffers(shape_dict) + + +class OrtTensorrtEngineBuilder(EngineBuilder): + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + hf_token=None, + device="cuda", + use_cuda_graph=False, + ): + """ + Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + device (str): + device to run. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + super().__init__( + EngineType.ORT_TRT, + pipeline_info, + max_batch_size=max_batch_size, + hf_token=hf_token, + device=device, + use_cuda_graph=use_cuda_graph, + ) + + def has_engine_file(self, engine_path): + if os.path.isdir(engine_path): + children = os.scandir(engine_path) + for entry in children: + if entry.is_file() and entry.name.endswith(".engine"): + return True + return False + + def get_work_space_size(self, model_name, max_workspace_size): + gibibyte = 2**30 + workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size + if workspace_size == 0: + _, free_mem, _ = cudart.cudaMemGetInfo() + # The following logic are adopted from TensorRT demo diffusion. + if free_mem > 6 * gibibyte: + workspace_size = free_mem - 4 * gibibyte + return workspace_size + + def build_engines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_image_height, + opt_image_width, + opt_batch_size=1, + force_engine_rebuild=False, + static_batch=False, + static_image_shape=True, + max_workspace_size=0, + device_id=0, + ): + self.torch_device = torch.device("cuda", device_id) + self.load_models(framework_model_dir) + + if force_engine_rebuild: + if os.path.isdir(onnx_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) + shutil.rmtree(onnx_dir) + if os.path.isdir(engine_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) + shutil.rmtree(engine_dir) + + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + + # Export models to ONNX + for model_name, model_obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + + profile_id = model_obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + if not self.has_engine_file(engine_path): + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): + logger.info(f"Exporting model: {onnx_path}") + model = model_obj.load_model(framework_model_dir, self.hf_token) + with torch.inference_mode(), torch.autocast("cuda"): + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.info("Found cached model: %s", onnx_path) + + # Optimize onnx + if not os.path.exists(onnx_opt_path): + logger.info("Generating optimizing model: %s", onnx_opt_path) + model_obj.optimize_trt(onnx_path, onnx_opt_path) + else: + logger.info("Found cached optimized model: %s", onnx_opt_path) + + built_engines = {} + for model_name, model_obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + + profile_id = model_obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape + ) + + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + + if not self.has_engine_file(engine_path): + logger.info( + "Building TensorRT engine for %s from %s to %s. It can take a while to complete...", + model_name, + onnx_opt_path, + engine_path, + ) + else: + logger.info("Reuse cached TensorRT engine in directory %s", engine_path) + + input_profile = model_obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=static_batch, + static_image_shape=static_image_shape, + ) + + engine = OrtTensorrtEngine( + engine_path, + device_id, + onnx_opt_path, + fp16=True, + input_profile=input_profile, + workspace_size=self.get_work_space_size(model_name, max_workspace_size), + enable_cuda_graph=self.use_cuda_graph, + ) + + built_engines[model_name] = engine + + self.engines = built_engines + + return built_engines + + def run_engine(self, model_name, feed_dict): + return self.engines[model_name].infer(feed_dict) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py new file mode 100644 index 0000000000..4a924abfb8 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py @@ -0,0 +1,507 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import gc +import os +import pathlib +from collections import OrderedDict + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import tensorrt as trt +import torch +from cuda import cudart +from diffusion_models import PipelineInfo +from engine_builder import EngineBuilder, EngineType +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.trt import ( + CreateConfig, + ModifyNetworkOutputs, + Profile, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, +) +from trt_utilities import TRT_LOGGER + +# Map of numpy dtype -> torch dtype +numpy_to_torch_dtype_dict = { + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, +} + + +def _cuda_assert(cuda_ret): + err = cuda_ret[0] + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError( + f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" + ) + if len(cuda_ret) > 1: + return cuda_ret[1] + return None + + +class TensorrtEngine: + def __init__( + self, + engine_path, + ): + self.engine_path = engine_path + self.engine = None + self.context = None + self.buffers = OrderedDict() + self.tensors = OrderedDict() + self.cuda_graph_instance = None + + def __del__(self): + del self.engine + del self.context + del self.buffers + del self.tensors + + def refit(self, onnx_path, onnx_refit_path): + def convert_int64(arr): + if len(arr.shape) == 0: + return np.int32(arr) + return arr + + def add_to_map(refit_dict, name, values): + if name in refit_dict: + assert refit_dict[name] is None + if values.dtype == np.int64: + values = convert_int64(values) + refit_dict[name] = values + + print(f"Refitting TensorRT engine with {onnx_refit_path} weights") + refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes + + # Construct mapping from weight names in refit model -> original model + name_map = {} + for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes): + refit_node = refit_nodes[n] + assert node.op == refit_node.op + # Constant nodes in ONNX do not have inputs but have a constant output + if node.op == "Constant": + name_map[refit_node.outputs[0].name] = node.outputs[0].name + # Handle scale and bias weights + elif node.op == "Conv": + if node.inputs[1].__class__ == gs.Constant: + name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL" + if node.inputs[2].__class__ == gs.Constant: + name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS" + # For all other nodes: find node inputs that are initializers (gs.Constant) + else: + for i, inp in enumerate(node.inputs): + if inp.__class__ == gs.Constant: + name_map[refit_node.inputs[i].name] = inp.name + + def map_name(name): + if name in name_map: + return name_map[name] + return name + + # Construct refit dictionary + refit_dict = {} + refitter = trt.Refitter(self.engine, TRT_LOGGER) + all_weights = refitter.get_all() + for layer_name, role in zip(all_weights[0], all_weights[1]): + # for specialized roles, use a unique name in the map: + if role == trt.WeightsRole.KERNEL: + name = layer_name + "_TRTKERNEL" + elif role == trt.WeightsRole.BIAS: + name = layer_name + "_TRTBIAS" + else: + name = layer_name + + assert name not in refit_dict, "Found duplicate layer: " + name + refit_dict[name] = None + + for n in refit_nodes: + # Constant nodes in ONNX do not have inputs but have a constant output + if n.op == "Constant": + name = map_name(n.outputs[0].name) + print(f"Add Constant {name}\n") + add_to_map(refit_dict, name, n.outputs[0].values) + + # Handle scale and bias weights + elif n.op == "Conv": + if n.inputs[1].__class__ == gs.Constant: + name = map_name(n.name + "_TRTKERNEL") + add_to_map(refit_dict, name, n.inputs[1].values) + + if n.inputs[2].__class__ == gs.Constant: + name = map_name(n.name + "_TRTBIAS") + add_to_map(refit_dict, name, n.inputs[2].values) + + # For all other nodes: find node inputs that are initializers (AKA gs.Constant) + else: + for inp in n.inputs: + name = map_name(inp.name) + if inp.__class__ == gs.Constant: + add_to_map(refit_dict, name, inp.values) + + for layer_name, weights_role in zip(all_weights[0], all_weights[1]): + if weights_role == trt.WeightsRole.KERNEL: + custom_name = layer_name + "_TRTKERNEL" + elif weights_role == trt.WeightsRole.BIAS: + custom_name = layer_name + "_TRTBIAS" + else: + custom_name = layer_name + + # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model + if layer_name.startswith("onnx::Trilu"): + continue + + if refit_dict[custom_name] is not None: + refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) + else: + print(f"[W] No refit weights for layer: {layer_name}") + + if not refitter.refit_cuda_engine(): + print("Failed to refit!") + exit(0) + + def build( + self, + onnx_path, + fp16, + input_profile=None, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + update_output_names=None, + ): + print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + p = Profile() + if input_profile: + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + + config_kwargs = {} + if not enable_all_tactics: + config_kwargs["tactic_sources"] = [] + + network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) + if update_output_names: + print(f"Updating network outputs to {update_output_names}") + network = ModifyNetworkOutputs(network, update_output_names) + engine = engine_from_network( + network, + config=CreateConfig( + fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs + ), + save_timing_cache=timing_cache, + ) + save_engine(engine, path=self.engine_path) + + def load(self): + print(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self, reuse_device_memory=None): + if reuse_device_memory: + self.context = self.engine.create_execution_context_without_device_memory() + self.context.device_memory = reuse_device_memory + else: + self.context = self.engine.create_execution_context() + + def allocate_buffers(self, shape_dict=None, device="cuda"): + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + if shape_dict and binding in shape_dict: + shape = shape_dict[binding] + else: + shape = self.engine.get_binding_shape(binding) + dtype = trt.nptype(self.engine.get_binding_dtype(binding)) + if self.engine.binding_is_input(binding): + self.context.set_binding_shape(idx, shape) + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) + self.tensors[binding] = tensor + + def infer(self, feed_dict, stream, use_cuda_graph=False): + for name, buf in feed_dict.items(): + self.tensors[name].copy_(buf) + + for name, tensor in self.tensors.items(): + self.context.set_tensor_address(name, tensor.data_ptr()) + + if use_cuda_graph: + if self.cuda_graph_instance is not None: + _cuda_assert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + _cuda_assert(cudart.cudaStreamSynchronize(stream)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + # capture cuda graph + _cuda_assert( + cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal) + ) + self.context.execute_async_v3(stream) + self.graph = _cuda_assert(cudart.cudaStreamEndCapture(stream)) + + from cuda import nvrtc + + result, major, minor = nvrtc.nvrtcVersion() + assert result == nvrtc.nvrtcResult(0) + if major < 12: + self.cuda_graph_instance = _cuda_assert( + cudart.cudaGraphInstantiate(self.graph, b"", 0) + ) # cuda < 12 + else: + self.cuda_graph_instance = _cuda_assert(cudart.cudaGraphInstantiate(self.graph, 0)) # cuda >= 12 + else: + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +class TensorrtEngineBuilder(EngineBuilder): + """ + Helper class to hide the detail of TensorRT Engine from pipeline. + """ + + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + hf_token=None, + device="cuda", + use_cuda_graph=False, + ): + """ + Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + device (str): + device to run. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + super().__init__( + EngineType.TRT, + pipeline_info, + max_batch_size=max_batch_size, + hf_token=hf_token, + device=device, + use_cuda_graph=use_cuda_graph, + ) + + self.stream = None + self.shared_device_memory = None + + def load_resources(self, image_height, image_width, batch_size): + super().load_resources(image_height, image_width, batch_size) + + self.stream = _cuda_assert(cudart.cudaStreamCreate()) + + def teardown(self): + super().teardown() + + if self.shared_device_memory: + cudart.cudaFree(self.shared_device_memory) + + cudart.cudaStreamDestroy(self.stream) + del self.stream + + def load_engines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_batch_size, + opt_image_height, + opt_image_width, + force_export=False, + force_optimize=False, + force_build=False, + static_batch=False, + static_shape=True, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + onnx_refit_dir=None, + ): + """ + Build and load engines for TensorRT accelerated inference. + Export ONNX models first, if applicable. + + Args: + engine_dir (str): + Directory to write the TensorRT engines. + framework_model_dir (str): + Directory to write the framework model ckpt. + onnx_dir (str): + Directory to write the ONNX models. + onnx_opset (int): + ONNX opset version to export the models. + opt_batch_size (int): + Batch size to optimize for during engine building. + opt_image_height (int): + Image height to optimize for during engine building. Must be a multiple of 8. + opt_image_width (int): + Image width to optimize for during engine building. Must be a multiple of 8. + force_export (bool): + Force re-exporting the ONNX models. + force_optimize (bool): + Force re-optimizing the ONNX models. + force_build (bool): + Force re-building the TensorRT engine. + static_batch (bool): + Build engine only for specified opt_batch_size. + static_shape (bool): + Build engine only for specified opt_image_height & opt_image_width. Default = True. + enable_refit (bool): + Build engines with refit option enabled. + enable_preview (bool): + Enable TensorRT preview features. + enable_all_tactics (bool): + Enable all tactic sources during TensorRT engine builds. + timing_cache (str): + Path to the timing cache to accelerate build or None + onnx_refit_dir (str): + Directory containing refit ONNX models. + """ + # Create directory + for directory in [engine_dir, onnx_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + self.load_models(framework_model_dir) + + # Export models to ONNX + for model_name, obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + profile_id = obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + if force_export or force_build or not os.path.exists(engine_path): + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + if force_export or not os.path.exists(onnx_opt_path): + if force_export or not os.path.exists(onnx_path): + print(f"Exporting model: {onnx_path}") + model = obj.load_model(framework_model_dir, self.hf_token) + with torch.inference_mode(), torch.autocast("cuda"): + inputs = obj.get_sample_input(1, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=obj.get_input_names(), + output_names=obj.get_output_names(), + dynamic_axes=obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + print(f"Found cached model: {onnx_path}") + + # Optimize onnx + if force_optimize or not os.path.exists(onnx_opt_path): + print(f"Generating optimizing model: {onnx_opt_path}") + obj.optimize_trt(onnx_path, onnx_opt_path) + else: + print(f"Found cached optimized model: {onnx_opt_path} ") + + # Build TensorRT engines + for model_name, obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + profile_id = obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + engine = TensorrtEngine(engine_path) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + + if force_build or not os.path.exists(engine.engine_path): + engine.build( + onnx_opt_path, + fp16=True, + input_profile=obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch, + static_shape, + ), + enable_refit=enable_refit, + enable_preview=enable_preview, + enable_all_tactics=enable_all_tactics, + timing_cache=timing_cache, + update_output_names=None, + ) + self.engines[model_name] = engine + + # Load TensorRT engines + for model_name in self.models: + if model_name == "vae" and self.vae_torch_fallback: + continue + self.engines[model_name].load() + if onnx_refit_dir: + onnx_refit_path = self.get_onnx_path(model_name, onnx_refit_dir, opt=True) + if os.path.exists(onnx_refit_path): + self.engines[model_name].refit(onnx_opt_path, onnx_refit_path) + + def max_device_memory(self): + max_device_memory = 0 + for _model_name, engine in self.engines.items(): + max_device_memory = max(max_device_memory, engine.engine.device_memory_size) + return max_device_memory + + def activate_engines(self, shared_device_memory=None): + if shared_device_memory is None: + max_device_memory = self.max_device_memory() + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + self.shared_device_memory = shared_device_memory + # Load and activate TensorRT engines + for engine in self.engines.values(): + engine.activate(reuse_device_memory=self.shared_device_memory) + + def run_engine(self, model_name, feed_dict): + return self.engines[model_name].infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py deleted file mode 100644 index 0f7688a3df..0000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py +++ /dev/null @@ -1,368 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -# -# Copyright 2023 The HuggingFace Inc. team. -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Models used in Stable diffusion. -""" -import logging - -import onnx -import onnx_graphsurgeon as gs -import torch -from onnx import shape_inference -from ort_optimizer import OrtStableDiffusionOptimizer -from polygraphy.backend.onnx.loader import fold_constants - -logger = logging.getLogger(__name__) - - -class TrtOptimizer: - def __init__(self, onnx_graph): - self.graph = gs.import_onnx(onnx_graph) - - def cleanup(self): - self.graph.cleanup().toposort() - - def get_optimized_onnx_graph(self): - return gs.export_onnx(self.graph) - - def select_outputs(self, keep, names=None): - self.graph.outputs = [self.graph.outputs[o] for o in keep] - if names: - for i, name in enumerate(names): - self.graph.outputs[i].name = name - - def fold_constants(self): - onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) - self.graph = gs.import_onnx(onnx_graph) - - def infer_shapes(self): - onnx_graph = gs.export_onnx(self.graph) - if onnx_graph.ByteSize() > 2147483648: - raise TypeError("ERROR: model size exceeds supported 2GB limit") - else: - onnx_graph = shape_inference.infer_shapes(onnx_graph) - - self.graph = gs.import_onnx(onnx_graph) - - -class BaseModel: - def __init__(self, model, name, device="cuda", fp16=False, max_batch_size=16, embedding_dim=768, text_maxlen=77): - self.model = model - self.name = name - self.fp16 = fp16 - self.device = device - - self.min_batch = 1 - self.max_batch = max_batch_size - self.min_image_shape = 256 # min image resolution: 256x256 - self.max_image_shape = 1024 # max image resolution: 1024x1024 - self.min_latent_shape = self.min_image_shape // 8 - self.max_latent_shape = self.max_image_shape // 8 - - self.embedding_dim = embedding_dim - self.text_maxlen = text_maxlen - - self.model_type = name.lower() if name in ["CLIP", "UNet"] else "vae" - self.ort_optimizer = OrtStableDiffusionOptimizer(self.model_type) - - def get_model(self): - return self.model - - def get_input_names(self): - pass - - def get_output_names(self): - pass - - def get_dynamic_axes(self): - return None - - def get_sample_input(self, batch_size, image_height, image_width): - pass - - def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape): - """For TensorRT EP""" - ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - _, - _, - _, - _, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - - profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" - - if self.name != "CLIP": - if static_image_shape: - profile_id += f"_h_{image_height}_w_{image_width}" - else: - profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}" - - return profile_id - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - """For TensorRT""" - return None - - def get_shape_dict(self, batch_size, image_height, image_width): - return None - - def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): - self.ort_optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16) - - def optimize_trt(self, input_onnx_path, optimized_onnx_path): - onnx_graph = onnx.load(input_onnx_path) - opt = TrtOptimizer(onnx_graph) - opt.cleanup() - opt.fold_constants() - opt.infer_shapes() - opt.cleanup() - onnx_opt_graph = opt.get_optimized_onnx_graph() - onnx.save(onnx_opt_graph, optimized_onnx_path) - - def check_dims(self, batch_size, image_height, image_width): - assert batch_size >= self.min_batch and batch_size <= self.max_batch - assert image_height % 8 == 0 or image_width % 8 == 0 - latent_height = image_height // 8 - latent_width = image_width // 8 - assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape - assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape - return (latent_height, latent_width) - - def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape): - min_batch = batch_size if static_batch else self.min_batch - max_batch = batch_size if static_batch else self.max_batch - latent_height = image_height // 8 - latent_width = image_width // 8 - min_image_height = image_height if static_image_shape else self.min_image_shape - max_image_height = image_height if static_image_shape else self.max_image_shape - min_image_width = image_width if static_image_shape else self.min_image_shape - max_image_width = image_width if static_image_shape else self.max_image_shape - min_latent_height = latent_height if static_image_shape else self.min_latent_shape - max_latent_height = latent_height if static_image_shape else self.max_latent_shape - min_latent_width = latent_width if static_image_shape else self.min_latent_shape - max_latent_width = latent_width if static_image_shape else self.max_latent_shape - return ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) - - -class CLIP(BaseModel): - def __init__(self, model, device, max_batch_size, embedding_dim): - super().__init__( - model=model, - name="CLIP", - device=device, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - ) - - def get_input_names(self): - return ["input_ids"] - - def get_output_names(self): - return ["text_embeddings"] - - def get_dynamic_axes(self): - return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - self.check_dims(batch_size, image_height, image_width) - min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_image_shape - ) - return { - "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return { - "input_ids": (batch_size, self.text_maxlen), - "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), - } - - def get_sample_input(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) - - def optimize_trt(self, input_onnx_path, optimized_onnx_path): - onnx_graph = onnx.load(input_onnx_path) - opt = TrtOptimizer(onnx_graph) - opt.select_outputs([0]) # delete graph output#1 - opt.cleanup() - opt.fold_constants() - opt.infer_shapes() - opt.select_outputs([0], names=["text_embeddings"]) # rename network output - opt.cleanup() - onnx_opt_graph = opt.get_optimized_onnx_graph() - onnx.save(onnx_opt_graph, optimized_onnx_path) - - -class UNet(BaseModel): - def __init__( - self, - model, - device="cuda", - fp16=False, # used by TRT - max_batch_size=16, - embedding_dim=768, - text_maxlen=77, - unet_dim=4, - ): - super().__init__( - model=model, - name="UNet", - device=device, - fp16=fp16, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - text_maxlen=text_maxlen, - ) - self.unet_dim = unet_dim - - def get_input_names(self): - return ["sample", "timestep", "encoder_hidden_states"] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - return { - "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "encoder_hidden_states": [ - (2 * min_batch, self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "timestep": [1], - "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - dtype = torch.float16 if self.fp16 else torch.float32 - return ( - torch.randn( - 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device - ), - torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), - ) - - -class VAE(BaseModel): - def __init__(self, model, device, max_batch_size, embedding_dim): - super().__init__( - model=model, - name="VAE Decoder", - device=device, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - ) - - def get_input_names(self): - return ["latent"] - - def get_output_names(self): - return ["images"] - - def get_dynamic_axes(self): - return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - return { - "latent": [ - (min_batch, 4, min_latent_height, min_latent_width), - (batch_size, 4, latent_height, latent_width), - (max_batch, 4, max_latent_height, max_latent_width), - ] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "latent": (batch_size, 4, latent_height, latent_width), - "images": (batch_size, 3, image_height, image_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py index 6134fa7bdd..37785869a3 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py @@ -43,16 +43,14 @@ StableDiffusionSafetyChecker, ) from diffusers.schedulers import DDIMScheduler -from diffusers.utils import DIFFUSERS_CACHE -from huggingface_hub import snapshot_download -from models import CLIP, VAE, UNet -from ort_utils import Engines +from diffusion_models import CLIP, VAE, PipelineInfo, UNet +from ort_utils import Engines, StableDiffusionPipelineMixin from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer logger = logging.getLogger(__name__) -class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline): +class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipelineMixin, StableDiffusionPipeline): r""" Pipeline for text-to-image generation using CUDA provider in ONNX Runtime. This pipeline inherits from [`StableDiffusionPipeline`]. Check the documentation in super class for most parameters. @@ -70,11 +68,12 @@ def __init__( requires_safety_checker: bool = True, # ONNX export parameters onnx_opset: int = 14, - onnx_dir: str = "raw_onnx", + onnx_dir: str = "onnx_ort", # Onnxruntime execution provider parameters - engine_dir: str = "onnxruntime_optimized_onnx", + engine_dir: str = "ORT_CUDA", force_engine_rebuild: bool = False, enable_cuda_graph: bool = False, + pipeline_info: PipelineInfo = None, ): super().__init__( vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker @@ -96,51 +95,38 @@ def __init__( self.fp16 = False - def __load_models(self): - self.embedding_dim = self.text_encoder.config.hidden_size + self.pipeline_info = pipeline_info - self.models["clip"] = CLIP( - self.text_encoder, - device=self.torch_device, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - ) + def load_models(self): + assert self.pipeline_info.clip_embedding_dim() == self.text_encoder.config.hidden_size - self.models["unet"] = UNet( - self.unet, - device=self.torch_device, - fp16=self.fp16, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - unet_dim=(9 if self.inpaint else 4), - ) + stages = self.pipeline_info.stages() + if "clip" in stages: + self.models["clip"] = CLIP( + self.pipeline_info, + self.text_encoder, + device=self.torch_device, + max_batch_size=self.max_batch_size, + clip_skip=0, + ) - self.models["vae"] = VAE( - self.vae, device=self.torch_device, max_batch_size=self.max_batch_size, embedding_dim=self.embedding_dim - ) + if "unet" in stages: + self.models["unet"] = UNet( + self.pipeline_info, + self.unet, + device=self.torch_device, + fp16=False, + max_batch_size=self.max_batch_size, + unet_dim=(9 if self.pipeline_info.is_inpaint() else 4), + ) - @classmethod - def set_cached_folder(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - - cls.cached_folder = ( - pretrained_model_name_or_path - if os.path.isdir(pretrained_model_name_or_path) - else snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, + if "vae" in stages: + self.models["vae"] = VAE( + self.pipeline_info, + self.vae, + device=self.torch_device, + max_batch_size=self.max_batch_size, ) - ) def to( self, @@ -156,7 +142,7 @@ def to( # load models self.fp16 = torch_dtype == torch.float16 - self.__load_models() + self.load_models() # build engines self.engines.build( @@ -180,88 +166,6 @@ def to( return self - def __encode_prompt(self, prompt, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - """ - # Tokenize prompt - text_input_ids = ( - self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = ( - self.engines.get_engine("clip").infer({"input_ids": text_input_ids})["text_embeddings"].clone() - ) - - # Tokenize negative prompt - uncond_input_ids = ( - self.tokenizer( - negative_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - uncond_embeddings = self.engines.get_engine("clip").infer({"input_ids": uncond_input_ids})["text_embeddings"] - - # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) - - return text_embeddings - - def __denoise_latent(self, latents, text_embeddings, timesteps=None, mask=None, masked_image_latents=None): - if not isinstance(timesteps, torch.Tensor): - timesteps = self.scheduler.timesteps - - for _step_index, timestep in enumerate(timesteps): - # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) - if isinstance(mask, torch.Tensor): - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - timestep_float = timestep.to(torch.float16) if self.fp16 else timestep.to(torch.float32) - - # Predict the noise residual - noise_pred = self.engines.get_engine("unet").infer( - {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, - )["latent"] - - # Perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample - - latents = 1.0 / 0.18215 * latents - return latents - - def __decode_latent(self, latents): - images = self.engines.get_engine("vae").infer({"latent": latents})["images"] - images = (images / 2 + 0.5).clamp(0, 1) - return images.cpu().permute(0, 2, 3, 1).float().numpy() - def __allocate_buffers(self, image_height, image_width, batch_size): # Allocate output tensors for I/O bindings for model_name, obj in self.models.items(): @@ -337,7 +241,7 @@ def __call__( with torch.inference_mode(), torch.autocast("cuda"): # CLIP text encoder - text_embeddings = self.__encode_prompt(prompt, negative_prompt) + text_embeddings = self.encode_prompt(self.engines.get_engine("clip"), prompt, negative_prompt) # Pre-initialize latents num_channels_latents = self.unet_in_channels @@ -352,30 +256,37 @@ def __call__( ) # UNet denoiser - latents = self.__denoise_latent(latents, text_embeddings) + latents = self.denoise_latent( + self.engines.get_engine("unet"), latents, text_embeddings, timestep_fp16=self.fp16 + ) # VAE decode latent - images = self.__decode_latent(latents) + images = self.decode_latent(self.engines.get_engine("vae"), latents) images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) images = self.numpy_to_pil(images) return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) -if __name__ == "__main__": - model_name_or_path = "runwayml/stable-diffusion-v1-5" +def example(): + pipeline_info = PipelineInfo("1.5") + model_name_or_path = pipeline_info.name() scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") - pipe = OnnxruntimeCudaStableDiffusionPipeline.from_pretrained( model_name_or_path, scheduler=scheduler, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models - pipe.set_cached_folder(model_name_or_path) + pipe.set_cached_folder(model_name_or_path, resume_download=True, local_files_only=True) pipe = pipe.to("cuda", torch_dtype=torch.float16) prompt = "photorealistic new zealand hills" image = pipe(prompt).images[0] - image.save("ort_trt_txt2img_new_zealand_hills.png") + image.save("ort_cuda_txt2img_new_zealand_hills.png") + + +if __name__ == "__main__": + example() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py index 6f3c215f36..c663e37c7e 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py @@ -32,13 +32,11 @@ pip install onnxruntime-gpu """ -import gc +import logging import os -import shutil from typing import List, Optional, Union import torch -from cuda import cudart from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import ( StableDiffusionPipeline, @@ -46,224 +44,15 @@ StableDiffusionSafetyChecker, ) from diffusers.schedulers import DDIMScheduler -from diffusers.utils import DIFFUSERS_CACHE, logging -from huggingface_hub import snapshot_download -from models import CLIP, VAE, UNet -from ort_utils import OrtCudaSession +from diffusion_models import PipelineInfo +from engine_builder_ort_trt import OrtTensorrtEngineBuilder +from ort_utils import StableDiffusionPipelineMixin from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -import onnxruntime as ort +logger = logging.getLogger(__name__) -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -class Engine(OrtCudaSession): - def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, workspace_size, enable_cuda_graph): - self.engine_path = engine_path - self.ort_trt_provider_options = self.get_tensorrt_provider_options( - input_profile, - workspace_size, - fp16, - device_id, - enable_cuda_graph, - ) - - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL - ort_session = ort.InferenceSession( - onnx_path, - sess_options, - providers=[ - ("TensorrtExecutionProvider", self.ort_trt_provider_options), - ], - ) - - device = torch.device("cuda", device_id) - super().__init__(ort_session, device, enable_cuda_graph) - - def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph): - trt_ep_options = { - "device_id": device_id, - "trt_fp16_enable": fp16, - "trt_engine_cache_enable": True, - "trt_timing_cache_enable": True, - "trt_detailed_build_log": True, - "trt_engine_cache_path": self.engine_path, - } - - if enable_cuda_graph: - trt_ep_options["trt_cuda_graph_enable"] = True - - if workspace_size > 0: - trt_ep_options["trt_max_workspace_size"] = workspace_size - - if input_profile: - min_shapes = [] - max_shapes = [] - opt_shapes = [] - for name, profile in input_profile.items(): - assert isinstance(profile, list) and len(profile) == 3 - min_shape = profile[0] - opt_shape = profile[1] - max_shape = profile[2] - assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape) - - min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape])) - opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape])) - max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape])) - - trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes) - trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes) - trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes) - - logger.info("trt_ep_options=%s", trt_ep_options) - - return trt_ep_options - - -def get_onnx_path(model_name, onnx_dir, opt=True): - return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx") - - -def get_engine_path(engine_dir, model_name, profile_id): - return os.path.join(engine_dir, model_name + profile_id) - - -def has_engine_file(engine_path): - if os.path.isdir(engine_path): - children = os.scandir(engine_path) - for entry in children: - if entry.is_file() and entry.name.endswith(".engine"): - return True - return False - - -def get_work_space_size(model_name, max_workspace_size): - gibibyte = 2**30 - workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size - if workspace_size == 0: - _, free_mem, _ = cudart.cudaMemGetInfo() - # The following logic are adopted from TensorRT demo diffusion. - if free_mem > 6 * gibibyte: - workspace_size = free_mem - 4 * gibibyte - return workspace_size - - -def build_engines( - models, - engine_dir, - onnx_dir, - onnx_opset, - opt_image_height, - opt_image_width, - opt_batch_size=1, - force_engine_rebuild=False, - static_batch=False, - static_image_shape=True, - max_workspace_size=0, - device_id=0, - enable_cuda_graph=False, -): - if force_engine_rebuild: - if os.path.isdir(onnx_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) - shutil.rmtree(onnx_dir) - if os.path.isdir(engine_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) - shutil.rmtree(engine_dir) - - if not os.path.isdir(engine_dir): - os.makedirs(engine_dir) - - if not os.path.isdir(onnx_dir): - os.makedirs(onnx_dir) - - # Export models to ONNX - for model_name, model_obj in models.items(): - profile_id = model_obj.get_profile_id( - opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape - ) - engine_path = get_engine_path(engine_dir, model_name, profile_id) - if not has_engine_file(engine_path): - onnx_path = get_onnx_path(model_name, onnx_dir, opt=False) - onnx_opt_path = get_onnx_path(model_name, onnx_dir) - if not os.path.exists(onnx_opt_path): - if not os.path.exists(onnx_path): - logger.info(f"Exporting model: {onnx_path}") - model = model_obj.get_model() - with torch.inference_mode(), torch.autocast("cuda"): - inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) - torch.onnx.export( - model, - inputs, - onnx_path, - export_params=True, - opset_version=onnx_opset, - do_constant_folding=True, - input_names=model_obj.get_input_names(), - output_names=model_obj.get_output_names(), - dynamic_axes=model_obj.get_dynamic_axes(), - ) - del model - torch.cuda.empty_cache() - gc.collect() - else: - logger.info("Found cached model: %s", onnx_path) - - # Optimize onnx - if not os.path.exists(onnx_opt_path): - logger.info("Generating optimizing model: %s", onnx_opt_path) - model_obj.optimize_trt(onnx_path, onnx_opt_path) - else: - logger.info("Found cached optimized model: %s", onnx_opt_path) - - built_engines = {} - for model_name, model_obj in models.items(): - profile_id = model_obj.get_profile_id( - opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape - ) - - engine_path = get_engine_path(engine_dir, model_name, profile_id) - onnx_opt_path = get_onnx_path(model_name, onnx_dir) - - if not has_engine_file(engine_path): - logger.info( - "Building TensorRT engine for %s from %s to %s. It can take a while to complete...", - model_name, - onnx_opt_path, - engine_path, - ) - else: - logger.info("Reuse cached TensorRT engine in directory %s", engine_path) - - input_profile = model_obj.get_input_profile( - opt_batch_size, - opt_image_height, - opt_image_width, - static_batch=static_batch, - static_image_shape=static_image_shape, - ) - - engine = Engine( - engine_path, - device_id, - onnx_opt_path, - fp16=True, - input_profile=input_profile, - workspace_size=get_work_space_size(model_name, max_workspace_size), - enable_cuda_graph=enable_cuda_graph, - ) - - built_engines[model_name] = engine - - return built_engines - - -def run_engine(engine, feed_dict): - return engine.infer(feed_dict) - - -class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipeline): +class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipelineMixin, StableDiffusionPipeline): r""" Pipeline for text-to-image generation using TensorRT execution provider in ONNX Runtime. @@ -285,11 +74,12 @@ def __init__( max_batch_size: int = 16, # ONNX export parameters onnx_opset: int = 17, - onnx_dir: str = "onnx", + onnx_dir: str = "onnx_trt", # TensorRT engine build parameters - engine_dir: str = "onnxruntime_tensorrt_engine", + engine_dir: str = "ORT_TRT", # use short name here to avoid path exceeds 260 chars in Windows. force_engine_rebuild: bool = False, enable_cuda_graph: bool = False, + pipeline_info: Optional[PipelineInfo] = None, ): super().__init__( vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker @@ -299,16 +89,14 @@ def __init__( self.image_height = image_height self.image_width = image_width - self.inpaint = False self.onnx_opset = onnx_opset self.onnx_dir = onnx_dir self.engine_dir = engine_dir self.force_engine_rebuild = force_engine_rebuild - self.enable_cuda_graph = enable_cuda_graph - # Although cuda graph requires static input shape, engine built with dyamic batch gets better performance in T4. + # Although cuda graph requires static input shape, engine built with dynamic batch gets better performance in T4. # Use static batch could reduce GPU memory footprint. - self.build_static_batch = False + self.build_static_batch = enable_cuda_graph # TODO: support dynamic image shape. self.build_dynamic_shape = False @@ -318,54 +106,13 @@ def __init__( if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512: self.max_batch_size = 4 - self.models = {} # loaded in __load_models() self.engines = {} # loaded in build_engines() - - def __load_models(self): - self.embedding_dim = self.text_encoder.config.hidden_size - - self.models["clip"] = CLIP( - self.text_encoder, - device=self.torch_device, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - ) - - self.models["unet"] = UNet( - self.unet, - device=self.torch_device, - fp16=True, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - unet_dim=(9 if self.inpaint else 4), + self.engine_builder = OrtTensorrtEngineBuilder( + pipeline_info, max_batch_size=max_batch_size, use_cuda_graph=enable_cuda_graph ) - self.models["vae"] = VAE( - self.vae, device=self.torch_device, max_batch_size=self.max_batch_size, embedding_dim=self.embedding_dim - ) - - @classmethod - def set_cached_folder(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - - cls.cached_folder = ( - pretrained_model_name_or_path - if os.path.isdir(pretrained_model_name_or_path) - else snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - ) - ) + self.pipeline_info = pipeline_info + self.stages = pipeline_info.stages() def to( self, @@ -381,11 +128,9 @@ def to( self.torch_device = self._execution_device logger.info(f"Running inference on device: {self.torch_device}") - self.__load_models() - - self.engines = build_engines( - self.models, + self.engines = self.engine_builder.build_engines( self.engine_dir, + None, self.onnx_dir, self.onnx_opset, opt_image_height=self.image_height, @@ -394,96 +139,10 @@ def to( static_batch=self.build_static_batch, static_image_shape=not self.build_dynamic_shape, device_id=self.torch_device.index, - enable_cuda_graph=self.enable_cuda_graph, ) return self - def __encode_prompt(self, prompt, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - """ - # Tokenize prompt - text_input_ids = ( - self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = run_engine(self.engines["clip"], {"input_ids": text_input_ids})["text_embeddings"].clone() - - # Tokenize negative prompt - uncond_input_ids = ( - self.tokenizer( - negative_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - uncond_embeddings = run_engine(self.engines["clip"], {"input_ids": uncond_input_ids})["text_embeddings"] - - # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) - - return text_embeddings - - def __denoise_latent(self, latents, text_embeddings, timesteps=None, mask=None, masked_image_latents=None): - if not isinstance(timesteps, torch.Tensor): - timesteps = self.scheduler.timesteps - for _step_index, timestep in enumerate(timesteps): - # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) - if isinstance(mask, torch.Tensor): - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - # Predict the noise residual - timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep - - noise_pred = run_engine( - self.engines["unet"], - {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, - )["latent"] - - # Perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample - - latents = 1.0 / 0.18215 * latents - return latents - - def __decode_latent(self, latents): - images = run_engine(self.engines["vae"], {"latent": latents})["images"] - images = (images / 2 + 0.5).clamp(0, 1) - return images.cpu().permute(0, 2, 3, 1).float().numpy() - - def __allocate_buffers(self, image_height, image_width, batch_size): - # Allocate output tensors for I/O bindings - for model_name, obj in self.models.items(): - self.engines[model_name].allocate_buffers(obj.get_shape_dict(batch_size, image_height, image_width)) - @torch.no_grad() def __call__( self, @@ -547,11 +206,11 @@ def __call__( f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4" ) - self.__allocate_buffers(self.image_height, self.image_width, batch_size) + self.engine_builder.load_resources(self.image_height, self.image_width, batch_size) with torch.inference_mode(), torch.autocast("cuda"): # CLIP text encoder - text_embeddings = self.__encode_prompt(prompt, negative_prompt) + text_embeddings = self.encode_prompt(self.engines["clip"], prompt, negative_prompt) # Pre-initialize latents num_channels_latents = self.unet.config.in_channels @@ -566,10 +225,10 @@ def __call__( ) # UNet denoiser - latents = self.__denoise_latent(latents, text_embeddings) + latents = self.denoise_latent(self.engines["unet"], latents, text_embeddings) # VAE decode latent - images = self.__decode_latent(latents) + images = self.decode_latent(self.engines["vae"], latents) images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) images = self.numpy_to_pil(images) @@ -577,8 +236,8 @@ def __call__( if __name__ == "__main__": - model_name_or_path = "runwayml/stable-diffusion-v1-5" - + pipeline_info = PipelineInfo("1.5") + model_name_or_path = pipeline_info.name() scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") pipe = OnnxruntimeTensorRTStableDiffusionPipeline.from_pretrained( @@ -589,6 +248,7 @@ def __call__( image_height=512, image_width=512, max_batch_size=4, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models and TensorRT Engines diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index a8e3c69332..cceed84a45 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -5,19 +5,15 @@ # # This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference. # -# Before running this script, follow README.md to setup python environment and convert stable diffusion checkpoint to float32 onnx models. +# Before running this script, follow README.md to setup python environment and convert stable diffusion checkpoint +# to float32 onnx models. # -# For example, the float32 ONNX pipeline is saved to ./sd-v1-5 directory, you can optimize and convert it to float16 like the following: +# For example, the float32 ONNX pipeline is saved to ./sd-v1-5 directory, you can optimize and convert it to float16 +# like the following: # python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16 # -# Note that the optimizations are carried out for CUDA Execution Provider at first, other EPs may not have the support for the fused opeartors. -# In this case, the users should disable the operator fusion manually to workaround. -# -# Stable diffusion 2.1 model will get black images using float16 Attention. A walkaround is to force Attention to run in float32 like the following: -# python optimize_pipeline.py -i ./sd-v2-1 -o ./sd-v2-1-fp16 --float16 --force_fp32_ops unet:Attention -# -# If you are using nightly package (or built from source), you can force MultiHeadAttention to run in float32: -# python optimize_pipeline.py -i ./sd-v2-1 -o ./sd-v2-1-fp16 --float16 --force_fp32_ops unet:MultiHeadAttention +# Note that the optimizations are carried out for CUDA Execution Provider at first, other EPs may not have the support +# for the fused operators. The users could disable the operator fusion manually to workaround. import argparse import logging @@ -25,8 +21,9 @@ import shutil import tempfile from pathlib import Path -from typing import List +from typing import List, Optional +import __init__ # noqa: F401. Walk-around to run this script directly import coloredlogs import onnx from fusion_options import FusionOptions @@ -41,11 +38,18 @@ logger = logging.getLogger(__name__) -def optimize_sd_pipeline( +def has_external_data(onnx_model_path): + original_model = onnx.load_model(str(onnx_model_path), load_external_data=False) + for initializer in original_model.graph.initializer: + if initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL: + return True + return False + + +def _optimize_sd_pipeline( source_dir: Path, target_dir: Path, - overwrite: bool, - use_external_data_format: bool, + use_external_data_format: Optional[bool], float16: bool, force_fp32_ops: List[str], enable_runtime_optimization: bool, @@ -56,8 +60,7 @@ def optimize_sd_pipeline( Args: source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models. target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models. - overwrite (bool): Overwrite files if exists. - use_external_data_format (bool): save onnx model to two files: one for onnx graph, another for weights + use_external_data_format (Optional[bool]): use external data format. float16 (bool): use half precision force_fp32_ops(List[str]): operators that are forced to run in float32. enable_runtime_optimization(bool): run graph optimization using Onnx Runtime. @@ -71,6 +74,7 @@ def optimize_sd_pipeline( "vae_encoder": "vae", "vae_decoder": "vae", "text_encoder": "clip", + "text_encoder_2": "clip", "safety_checker": "unet", } @@ -85,9 +89,12 @@ def optimize_sd_pipeline( "vae_encoder": [], "vae_decoder": [], "text_encoder": [], + "text_encoder_2": [], "safety_checker": [], } + is_xl = (source_dir / "text_encoder_2").exists() + if force_fp32_ops: for fp32_operator in force_fp32_ops: parts = fp32_operator.split(":") @@ -100,26 +107,21 @@ def optimize_sd_pipeline( for name, model_type in model_type_mapping.items(): onnx_model_path = source_dir / name / "model.onnx" - if not os.path.exists(onnx_model_path): - message = f"input onnx model does not exist: {onnx_model_path}." - if name not in ["safety_checker"]: - raise RuntimeError(message) + if name != "safety_checker": + logger.info("input onnx model does not exist: %s", onnx_model_path) + # some model are optional so we do not raise error here. continue # Prepare output directory optimized_model_path = target_dir / name / "model.onnx" output_dir = optimized_model_path.parent - if optimized_model_path.exists(): - if not overwrite: - raise RuntimeError(f"output onnx model path existed: {optimized_model_path}") - - if output_dir.exists(): - shutil.rmtree(output_dir) output_dir.mkdir(parents=True, exist_ok=True) + if use_external_data_format is None: + use_external_data_format = has_external_data(onnx_model_path) + # Graph fusion before fp16 conversion, otherwise they cannot be fused later. - # Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead. logger.info(f"Optimize {onnx_model_path}...") args.model_type = model_type @@ -143,12 +145,15 @@ def optimize_sd_pipeline( ) if float16: - logger.info("Convert %s to float16 ...", name) - op_block_list = ["RandomNormalLike"] - m.convert_float_to_float16( - keep_io_types=False, - op_block_list=op_block_list + force_fp32_operators[name], - ) + # For SD-XL, use FP16 in VAE decoder will cause NaN and black image so we keep it in FP32. + if is_xl and name == "vae_decoder": + logger.info("Skip converting %s to float16 to avoid NaN", name) + else: + logger.info("Convert %s to float16 ...", name) + m.convert_float_to_float16( + keep_io_types=False, + op_block_list=force_fp32_operators[name], + ) if enable_runtime_optimization and (float16 or (name not in ["unet"])): # Use this step to see the final graph that executed by Onnx Runtime. @@ -173,35 +178,24 @@ def optimize_sd_pipeline( logger.info("*" * 20) -def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool): +def _copy_extra_directory(source_dir: Path, target_dir: Path): """Copy extra directory that does not have onnx model Args: source_dir (Path): source directory target_dir (Path): target directory - overwrite (bool): overwrite if exists Raises: RuntimeError: source path does not exist - RuntimeError: output path exists but overwrite is false. """ - extra_dirs = ["scheduler", "tokenizer", "feature_extractor"] + extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "feature_extractor"] for name in extra_dirs: source_path = source_dir / name - if not os.path.exists(source_path): - message = f"source path does not exist: {source_path}" - if name not in ["feature_extractor"]: - raise RuntimeError(message) continue target_path = target_dir / name - if target_path.exists(): - if not overwrite: - raise RuntimeError(f"output path existed: {target_path}") - shutil.rmtree(target_path) - shutil.copytree(source_path, target_path) logger.info("%s => %s", source_path, target_path) @@ -212,15 +206,53 @@ def copy_extra_directory(source_dir: Path, target_dir: Path, overwrite: bool): raise RuntimeError(f"source path does not exist: {source_path}") target_path = target_dir / name - if target_path.exists(): - if not overwrite: - raise RuntimeError(f"output path existed: {target_path}") - os.remove(target_path) shutil.copyfile(source_path, target_path) logger.info("%s => %s", source_path, target_path) + # Some directory are optional + onnx_model_dirs = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder", "safety_checker"] + for onnx_model_dir in onnx_model_dirs: + source_path = source_dir / onnx_model_dir / "config.json" + target_path = target_dir / onnx_model_dir / "config.json" + if source_path.exists(): + target_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copyfile(source_path, target_path) + logger.info("%s => %s", source_path, target_path) + -def parse_arguments(): +def optimize_stable_diffusion_pipeline( + input_dir: str, + output_dir: str, + overwrite: bool, + use_external_data_format: Optional[bool], + float16: bool, + enable_runtime_optimization: bool, + args, +): + if os.path.exists(output_dir): + if overwrite: + shutil.rmtree(output_dir, ignore_errors=True) + else: + raise RuntimeError("output directory existed:{output_dir}. Add --overwrite to empty the directory.") + + source_dir = Path(input_dir) + target_dir = Path(output_dir) + target_dir.mkdir(parents=True, exist_ok=True) + + _copy_extra_directory(source_dir, target_dir) + + _optimize_sd_pipeline( + source_dir, + target_dir, + use_external_data_format, + float16, + args.force_fp32_ops, + enable_runtime_optimization, + args, + ) + + +def parse_arguments(argv: Optional[List[str]] = None): """Parse arguments Returns: @@ -264,7 +296,8 @@ def parse_arguments(): "--inspect", required=False, action="store_true", - help="Inspect the optimized graph from Onnx Runtime for debugging purpose. This option has no impact on model performance.", + help="Save the optimized graph from Onnx Runtime. " + "This option has no impact on inference performance except it might reduce session creation time.", ) parser.set_defaults(inspect=False) @@ -282,32 +315,25 @@ def parse_arguments(): required=False, action="store_true", help="Onnx model larger than 2GB need to use external data format. " - "Save onnx model to two files: one for onnx graph, another for large weights.", + "If specified, save each onnx model to two files: one for onnx graph, another for weights. " + "If not specified, use same format as original model by default. ", ) - parser.set_defaults(use_external_data_format=False) + parser.set_defaults(use_external_data_format=None) FusionOptions.add_arguments(parser) - args = parser.parse_args() + args = parser.parse_args(argv) return args -def main(): - coloredlogs.install(fmt="%(funcName)20s: %(message)s") - args = parse_arguments() +def main(argv: Optional[List[str]] = None): + args = parse_arguments(argv) logger.info("Arguments: %s", str(args)) - copy_extra_directory(Path(args.input), Path(args.output), args.overwrite) - optimize_sd_pipeline( - Path(args.input), - Path(args.output), - args.overwrite, - args.use_external_data_format, - args.float16, - args.force_fp32_ops, - args.inspect, - args, + optimize_stable_diffusion_pipeline( + args.input, args.output, args.overwrite, args.use_external_data_format, args.float16, args.inspect, args ) if __name__ == "__main__": + coloredlogs.install(fmt="%(funcName)20s: %(message)s") main() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 0824c8f07d..4b48396b6c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -12,6 +12,7 @@ from pathlib import Path import onnx +from packaging import version from onnxruntime.transformers.fusion_options import FusionOptions from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel @@ -32,53 +33,84 @@ def __init__(self, model_type: str): "clip": ClipOnnxModel, } - def optimize_by_ort(self, onnx_model): + def optimize_by_ort(self, onnx_model, use_external_data_format=False): # Use this step to see the final graph that executed by Onnx Runtime. with tempfile.TemporaryDirectory() as tmp_dir: # Save to a temporary file so that we can load it with Onnx Runtime. logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...") tmp_model_path = Path(tmp_dir) / "model.onnx" - onnx_model.save_model_to_file(str(tmp_model_path)) - ort_optimized_model_path = tmp_model_path + onnx_model.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format) + ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx" optimize_by_onnxruntime( - str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path) + str(tmp_model_path), + use_gpu=True, + optimized_model_path=str(ort_optimized_model_path), + save_as_external_data=use_external_data_format, + external_data_filename="optimized.onnx_data", ) model = onnx.load(str(ort_optimized_model_path), load_external_data=True) return self.model_type_class_mapping[self.model_type](model) - def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): + def optimize( + self, + input_fp32_onnx_path, + optimized_onnx_path, + float16=True, + keep_io_types=False, + fp32_op_list=None, + keep_outputs=None, + optimize_by_ort=True, + optimize_by_fusion=True, + final_target_float16=True, + ): """Optimize onnx model using ONNX Runtime transformers optimizer""" logger.info(f"Optimize {input_fp32_onnx_path}...") - fusion_options = FusionOptions(self.model_type) - if self.model_type in ["unet"] and not float16: - fusion_options.enable_packed_kv = False - fusion_options.enable_packed_qkv = False - - m = optimize_model( - input_fp32_onnx_path, - model_type=self.model_type, - num_heads=0, # will be deduced from graph - hidden_size=0, # will be deduced from graph - opt_level=0, - optimization_options=fusion_options, - use_gpu=True, - ) - - if self.model_type == "clip": - m.prune_graph(outputs=["text_embeddings"]) # remove the pooler_output, and only keep the first output. + + if optimize_by_fusion: + fusion_options = FusionOptions(self.model_type) + + # It is allowed float16=False and final_target_float16=True, for using fp32 as intermediate optimization step. + # For rare fp32 use case, we can disable packed kv/qkv since there is no fp32 TRT fused attention kernel. + if self.model_type in ["unet"] and not final_target_float16: + fusion_options.enable_packed_kv = False + fusion_options.enable_packed_qkv = False + + m = optimize_model( + input_fp32_onnx_path, + model_type=self.model_type, + num_heads=0, # will be deduced from graph + hidden_size=0, # will be deduced from graph + opt_level=0, + optimization_options=fusion_options, + use_gpu=True, + ) + else: + model = onnx.load_model(input_fp32_onnx_path, load_external_data=True) + m = self.model_type_class_mapping[self.model_type](model) + + if keep_outputs: + m.prune_graph(outputs=keep_outputs) + + use_external_data_format = m.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF + + # Note that ORT < 1.16 could not save model larger than 2GB. + # This step is is optional since it has no impact on inference latency. + # The optimized model is not portable. It could only run in the same execution provider (CUDA EP in this case). + # When the model has been optimized by onnxruntime, we can disable optimization in SessionOption + # to save session creation time. Another benefit is to inspect the final graph for developing purpose. + from onnxruntime import __version__ as ort_version + + if optimize_by_ort and (version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format): + m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format) if float16: logger.info("Convert to float16 ...") m.convert_float_to_float16( - keep_io_types=False, - op_block_list=["RandomNormalLike"], + keep_io_types=keep_io_types, + op_block_list=fp32_op_list, ) - # Note that ORT 1.15 could not save model larger than 2GB. This only works for float16 - if float16 or (self.model_type != "unet"): - m = self.optimize_by_ort(m) - m.get_operator_statistics() m.get_fused_operator_statistics() - m.save_model_to_file(optimized_onnx_path, use_external_data_format=(self.model_type == "unet") and not float16) + m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py index 7192e4ad55..0afa13a0f4 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py @@ -7,122 +7,36 @@ import logging import os import shutil -from collections import OrderedDict -from typing import Any, Dict +import sys +from typing import Union import torch import onnxruntime as ort -from onnxruntime.transformers.io_binding_helper import TypeHelper logger = logging.getLogger(__name__) -class OrtCudaSession: - """Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider""" - - def __init__(self, ort_session: ort.InferenceSession, device: torch.device, enable_cuda_graph=False): - self.ort_session = ort_session - self.input_names = [input.name for input in self.ort_session.get_inputs()] - self.output_names = [output.name for output in self.ort_session.get_outputs()] - self.io_name_to_numpy_type = TypeHelper.get_io_numpy_type_map(self.ort_session) - self.io_binding = self.ort_session.io_binding() - self.enable_cuda_graph = enable_cuda_graph - - self.input_tensors = OrderedDict() - self.output_tensors = OrderedDict() - self.device = device - - def __del__(self): - del self.input_tensors - del self.output_tensors - del self.io_binding - del self.ort_session - - def allocate_buffers(self, shape_dict: Dict[str, tuple]): - """Allocate tensors for I/O Binding""" - if self.enable_cuda_graph: - for name, shape in shape_dict.items(): - if name in self.input_names: - # Reuse allocated buffer when the shape is same - if name in self.input_tensors: - if tuple(self.input_tensors[name].shape) == tuple(shape): - continue - raise RuntimeError("Expect static input shape for cuda graph") - - numpy_dtype = self.io_name_to_numpy_type[name] - tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( - device=self.device - ) - self.input_tensors[name] = tensor - - self.io_binding.bind_input( - name, - tensor.device.type, - tensor.device.index, - numpy_dtype, - list(tensor.size()), - tensor.data_ptr(), - ) - - for name, shape in shape_dict.items(): - if name in self.output_names: - # Reuse allocated buffer when the shape is same - if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape): - continue - - numpy_dtype = self.io_name_to_numpy_type[name] - tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( - device=self.device - ) - self.output_tensors[name] = tensor - - self.io_binding.bind_output( - name, - tensor.device.type, - tensor.device.index, - numpy_dtype, - list(tensor.size()), - tensor.data_ptr(), - ) - - def infer(self, feed_dict): - """Bind input tensors and run inference""" - for name, tensor in feed_dict.items(): - assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous() - if name in self.input_names: - if self.enable_cuda_graph: - assert self.input_tensors[name].nelement() == tensor.nelement() - assert tensor.device.type == "cuda" - # Update input tensor inplace since cuda graph requires input and output has fixed memory address. - from cuda import cudart - - cudart.cudaMemcpy( - self.input_tensors[name].data_ptr(), - tensor.data_ptr(), - tensor.element_size() * tensor.nelement(), - cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, - ) - else: - self.io_binding.bind_input( - name, - tensor.device.type, - tensor.device.index, - TypeHelper.torch_type_to_numpy_type(tensor.dtype), - [1] if len(tensor.shape) == 0 else list(tensor.shape), - tensor.data_ptr(), - ) +def add_transformers_dir_to_path(): + sys.path.append(os.path.dirname(__file__)) + + transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")) + if transformers_dir not in sys.path: + sys.path.append(transformers_dir) - self.ort_session.run_with_iobinding(self.io_binding) - return self.output_tensors +add_transformers_dir_to_path() +from io_binding_helper import CudaSession # noqa: E402. Walk-around to test locally -class Engine(OrtCudaSession): +# ----------------------------------------------------------------------------------------------------- +# Utilities for CUDA EP +# ----------------------------------------------------------------------------------------------------- +class Engine(CudaSession): def __init__(self, engine_path, provider: str, device_id: int = 0, enable_cuda_graph=False): self.engine_path = engine_path self.provider = provider - self.provider_options = self.get_cuda_provider_options(device_id, enable_cuda_graph) + self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) device = torch.device("cuda", device_id) ort_session = ort.InferenceSession( @@ -135,13 +49,6 @@ def __init__(self, engine_path, provider: str, device_id: int = 0, enable_cuda_g super().__init__(ort_session, device, enable_cuda_graph) - def get_cuda_provider_options(self, device_id: int, enable_cuda_graph: bool) -> Dict[str, Any]: - return { - "device_id": device_id, - "arena_extend_strategy": "kSameAsRequested", - "enable_cuda_graph": enable_cuda_graph, - } - class Engines: def __init__(self, provider, onnx_opset: int = 14): @@ -197,9 +104,16 @@ def build( model = model_obj.get_model().to(model_obj.device) with torch.inference_mode(): inputs = model_obj.get_sample_input(1, 512, 512) + fp32_inputs = tuple( + [ + (tensor.to(torch.float32) if tensor.dtype == torch.float16 else tensor) + for tensor in inputs + ] + ) + torch.onnx.export( model, - inputs, + fp32_inputs, onnx_path, export_params=True, opset_version=self.onnx_opset, @@ -224,3 +138,125 @@ def build( def get_engine(self, model_name): return self.engines[model_name] + + +def run_engine(engine, feed_dict): + return engine.infer(feed_dict) + + +# ----------------------------------------------------------------------------------------------------- +# Utilities for both CUDA and TensorRT EP +# ----------------------------------------------------------------------------------------------------- + + +class StableDiffusionPipelineMixin: + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def encode_prompt(self, clip_engine, prompt, negative_prompt): + """ + Encodes the prompt into text encoder hidden states. + """ + + # Tokenize prompt + text_input_ids = ( + self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + text_embeddings = run_engine(clip_engine, {"input_ids": text_input_ids})["text_embeddings"].clone() + + # Tokenize negative prompt + uncond_input_ids = ( + self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + uncond_embeddings = run_engine(clip_engine, {"input_ids": uncond_input_ids})["text_embeddings"] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + return text_embeddings + + def denoise_latent( + self, + unet_engine, + latents, + text_embeddings, + timesteps=None, + mask=None, + masked_image_latents=None, + timestep_fp16=False, + ): + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + + for _step_index, timestep in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # Predict the noise residual + timestep_float = timestep.to(torch.float16) if timestep_fp16 else timestep.to(torch.float32) + + noise_pred = run_engine( + unet_engine, + {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, + )["latent"] + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample + + latents = 1.0 / 0.18215 * latents + return latents + + def decode_latent(self, vae_engine, latents): + images = run_engine(vae_engine, {"latent": latents})["images"] + images = (images / 2 + 0.5).clamp(0, 1) + return images.cpu().permute(0, 2, 3, 1).float().numpy() + + def set_cached_folder(self, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): + from diffusers.utils import DIFFUSERS_CACHE + from huggingface_hub import snapshot_download + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + + self.cached_folder = ( + pretrained_model_name_or_path + if os.path.isdir(pretrained_model_name_or_path) + else snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py new file mode 100644 index 0000000000..faa3f8bfaa --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py @@ -0,0 +1,233 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import time + +import torch +from diffusion_models import PipelineInfo +from pipeline_stable_diffusion import StableDiffusionPipeline + + +class Img2ImgXLPipeline(StableDiffusionPipeline): + """ + Stable Diffusion Img2Img XL pipeline. + """ + + def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): + """ + Initializes the Img2Img XL Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of stable diffusion pipeline. + """ + assert pipeline_info.is_xl_refiner() + + super().__init__(pipeline_info, *args, **kwargs) + + self.requires_aesthetics_score = True + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + ): + if self.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0).to(device=self.device) + return add_time_ids + + def _infer( + self, + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="image", + ): + assert negative_prompt is None or len(prompt) == len(negative_prompt) + + original_size = (image_height, image_width) + crops_coords_top_left = (0, 0) + target_size = (image_height, image_width) + + strength = 0.3 + aesthetic_score = 6.0 + negative_aesthetic_score = 2.5 + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + with torch.inference_mode(), torch.autocast("cuda"): + batch_size = len(prompt) + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # Initialize timesteps + timesteps, t_start = self.initialize_timesteps(self.denoising_steps, strength) + + latent_timestep = timesteps[:1].repeat(batch_size) + + # CLIP text encoder 2 + text_embeddings, pooled_embeddings2 = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip2", + tokenizer=self.tokenizer2, + pooled_outputs=True, + output_hidden_states=True, + ) + + # Time embeddings + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=text_embeddings.dtype, + ) + + add_time_ids = add_time_ids.repeat(batch_size, 1) + + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + + # Pre-process input image + init_image = self.preprocess_images(batch_size, (init_image,))[0] + + # VAE encode init image + if init_image.shape[1] == 4: + init_latents = init_image + else: + init_latents = self.encode_image(init_image) + + # Add noise to latents using timesteps + noise = torch.randn(init_latents.shape, device=self.device, dtype=torch.float32, generator=self.generator) + latents = self.scheduler.add_noise(init_latents, noise, t_start, latent_timestep) + + # UNet denoiser + latents = self.denoise_latent( + latents, + text_embeddings, + timesteps=timesteps, + step_offset=t_start, + denoiser="unetxl", + guidance=guidance, + add_kwargs=add_kwargs, + ) + + with torch.inference_mode(): + # VAE decode latent + if return_type == "latent": + images = latents + else: + images = self.decode_latent(latents / self.vae_scaling_factor) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + if not warmup: + print("SD-XL Refiner Pipeline") + self.print_summary(e2e_tic, e2e_toc, batch_size) + self.save_images(images, "img2img-xl", prompt) + + return images, (e2e_toc - e2e_tic) * 1000.0 + + def run( + self, + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="image", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + init_image (tuple[torch.Tensor]): + Image from base pipeline. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + warmup (bool): + Indicate if this is a warmup run. + return_type (str): + It can be "latent" or "image". + """ + + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py new file mode 100644 index 0000000000..e28db2b771 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -0,0 +1,454 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import os +import pathlib +import random + +import nvtx +import torch +from cuda import cudart +from diffusion_models import PipelineInfo, get_tokenizer +from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, UniPCMultistepScheduler +from engine_builder import EngineType +from engine_builder_ort_cuda import OrtCudaEngineBuilder +from engine_builder_ort_trt import OrtTensorrtEngineBuilder +from engine_builder_tensorrt import TensorrtEngineBuilder + + +class StableDiffusionPipeline: + """ + Stable Diffusion pipeline using TensorRT. + """ + + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + scheduler="DDIM", + device="cuda", + output_dir=".", + hf_token=None, + verbose=False, + nvtx_profile=False, + use_cuda_graph=False, + framework_model_dir="pytorch_model", + engine_type: EngineType = EngineType.ORT_TRT, + ): + """ + Initializes the Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + scheduler (str): + The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC]. + device (str): + PyTorch device to run inference. Default: 'cuda' + output_dir (str): + Output directory for log files and image artifacts + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + verbose (bool): + Enable verbose logging. + nvtx_profile (bool): + Insert NVTX profiling markers. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + framework_model_dir (str): + cache directory for framework checkpoints + engine_type (EngineType) + backend engine type like ORT_TRT or TRT + """ + + self.pipeline_info = pipeline_info + self.version = pipeline_info.version + + self.vae_scaling_factor = pipeline_info.vae_scaling_factor() + + self.max_batch_size = max_batch_size + + self.framework_model_dir = framework_model_dir + self.output_dir = output_dir + for directory in [self.framework_model_dir, self.output_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + self.hf_token = hf_token + self.device = device + self.torch_device = torch.device(device, torch.cuda.current_device()) + self.verbose = verbose + self.nvtx_profile = nvtx_profile + + # Scheduler options + sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012} + if self.version in ("2.0", "2.1"): + sched_opts["prediction_type"] = "v_prediction" + else: + sched_opts["prediction_type"] = "epsilon" + + if scheduler == "DDIM": + self.scheduler = DDIMScheduler(device=self.device, **sched_opts) + elif scheduler == "EulerA": + self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts) + elif scheduler == "UniPC": + self.scheduler = UniPCMultistepScheduler(device=self.device) + else: + raise ValueError("Scheduler should be either DDIM, EulerA or UniPC") + + self.stages = pipeline_info.stages() + + self.vae_torch_fallback = self.pipeline_info.is_xl() + + self.use_cuda_graph = use_cuda_graph + + self.tokenizer = None + self.tokenizer2 = None + + self.generator = None + self.denoising_steps = None + self.actual_steps = None + + # backend engine + self.engine_type = engine_type + if engine_type == EngineType.TRT: + self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + elif engine_type == EngineType.ORT_TRT: + self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + elif engine_type == EngineType.ORT_CUDA: + self.backend = OrtCudaEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + else: + raise RuntimeError(f"Backend engine type {engine_type.name} is not supported") + + # Load text tokenizer + if not self.pipeline_info.is_xl_refiner(): + self.tokenizer = get_tokenizer( + self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer" + ) + + if self.pipeline_info.is_xl(): + self.tokenizer2 = get_tokenizer( + self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer_2" + ) + + # Create CUDA events + self.events = {} + for stage in ["clip", "denoise", "vae", "vae_encoder"]: + for marker in ["start", "stop"]: + self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1] + + def is_backend_tensorrt(self): + return self.engine_type == EngineType.TRT + + def set_denoising_steps(self, denoising_steps: int): + if self.denoising_steps != denoising_steps: + assert self.denoising_steps is None # TODO(tianleiwu): support changing steps in different runs + # Pre-compute latent input scales and linear multistep coefficients + self.scheduler.set_timesteps(denoising_steps) + self.scheduler.configure() + self.denoising_steps = denoising_steps + + def load_resources(self, image_height, image_width, batch_size): + # If engine is built with static input shape, call this only once after engine build. + # Otherwise, it need be called before every inference run. + self.backend.load_resources(image_height, image_width, batch_size) + + def set_random_seed(self, seed): + # Initialize noise generator. Usually, it is done before a batch of inference. + self.generator = torch.Generator(device="cuda").manual_seed(seed) if isinstance(seed, int) else None + + def teardown(self): + for e in self.events.values(): + cudart.cudaEventDestroy(e) + + if self.backend: + self.backend.teardown() + + def run_engine(self, model_name, feed_dict): + return self.backend.run_engine(model_name, feed_dict) + + def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width): + latents_dtype = torch.float32 # text_embeddings.dtype + latents_shape = (batch_size, unet_channels, latent_height, latent_width) + latents = torch.randn(latents_shape, device=self.device, dtype=latents_dtype, generator=self.generator) + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def initialize_timesteps(self, timesteps, strength): + self.scheduler.set_timesteps(timesteps) + offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0 + init_timestep = int(timesteps * strength) + offset + init_timestep = min(init_timestep, timesteps) + t_start = max(timesteps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + return timesteps, t_start + + def preprocess_images(self, batch_size, images=()): + if self.nvtx_profile: + nvtx_image_preprocess = nvtx.start_range(message="image_preprocess", color="pink") + init_images = [] + for i in images: + image = i.to(self.device).float() + if image.shape[0] != batch_size: + image = image.repeat(batch_size, 1, 1, 1) + init_images.append(image) + if self.nvtx_profile: + nvtx.end_range(nvtx_image_preprocess) + return tuple(init_images) + + def encode_prompt( + self, + prompt, + negative_prompt, + encoder="clip", + tokenizer=None, + pooled_outputs=False, + output_hidden_states=False, + force_zeros_for_empty_prompt=False, + ): + if tokenizer is None: + tokenizer = self.tokenizer + + if self.nvtx_profile: + nvtx_clip = nvtx.start_range(message="clip", color="green") + cudart.cudaEventRecord(self.events["clip-start"], 0) + + # Tokenize prompt + text_input_ids = ( + tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) + ) + + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + outputs = self.run_engine(encoder, {"input_ids": text_input_ids}) + text_embeddings = outputs["text_embeddings"].clone() + if output_hidden_states: + hidden_states = outputs["hidden_states"].clone() + + # Note: negative prompt embedding is not needed for SD XL when guidance < 1 + + # For SD XL base, handle force_zeros_for_empty_prompt + is_empty_negative_prompt = all([not i for i in negative_prompt]) + if force_zeros_for_empty_prompt and is_empty_negative_prompt: + uncond_embeddings = torch.zeros_like(text_embeddings) + if output_hidden_states: + uncond_hidden_states = torch.zeros_like(hidden_states) + else: + # Tokenize negative prompt + uncond_input_ids = ( + tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) + ) + + outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) + uncond_embeddings = outputs["text_embeddings"] + if output_hidden_states: + uncond_hidden_states = outputs["hidden_states"] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + if pooled_outputs: + pooled_output = text_embeddings + + if output_hidden_states: + text_embeddings = torch.cat([uncond_hidden_states, hidden_states]).to(dtype=torch.float16) + + cudart.cudaEventRecord(self.events["clip-stop"], 0) + if self.nvtx_profile: + nvtx.end_range(nvtx_clip) + + if pooled_outputs: + return text_embeddings, pooled_output + return text_embeddings + + def denoise_latent( + self, + latents, + text_embeddings, + denoiser="unet", + timesteps=None, + step_offset=0, + mask=None, + masked_image_latents=None, + guidance=7.5, + add_kwargs=None, + ): + assert guidance > 1.0, "Guidance has to be > 1.0" # TODO: remove this constraint + + cudart.cudaEventRecord(self.events["denoise-start"], 0) + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + + for step_index, timestep in enumerate(timesteps): + if self.nvtx_profile: + nvtx_latent_scale = nvtx.start_range(message="latent_scale", color="pink") + + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, step_offset + step_index, timestep + ) + + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + if self.nvtx_profile: + nvtx.end_range(nvtx_latent_scale) + + # Predict the noise residual + if self.nvtx_profile: + nvtx_unet = nvtx.start_range(message="unet", color="blue") + + timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep + + params = { + "sample": latent_model_input, + "timestep": timestep_float, + "encoder_hidden_states": text_embeddings, + } + if add_kwargs: + params.update(add_kwargs) + + noise_pred = self.run_engine(denoiser, params)["latent"] + + if self.nvtx_profile: + nvtx.end_range(nvtx_unet) + + if self.nvtx_profile: + nvtx_latent_step = nvtx.start_range(message="latent_step", color="pink") + + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) + + if type(self.scheduler) == UniPCMultistepScheduler: + latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + else: + latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep) + + if self.nvtx_profile: + nvtx.end_range(nvtx_latent_step) + + cudart.cudaEventRecord(self.events["denoise-stop"], 0) + + # The actual number of steps. It might be different from denoising_steps. + self.actual_steps = len(timesteps) + + return latents + + def encode_image(self, init_image): + if self.nvtx_profile: + nvtx_vae = nvtx.start_range(message="vae_encoder", color="red") + cudart.cudaEventRecord(self.events["vae_encoder-start"], 0) + init_latents = self.run_engine("vae_encoder", {"images": init_image})["latent"] + cudart.cudaEventRecord(self.events["vae_encoder-stop"], 0) + if self.nvtx_profile: + nvtx.end_range(nvtx_vae) + + init_latents = self.vae_scaling_factor * init_latents + return init_latents + + def decode_latent(self, latents): + if self.nvtx_profile: + nvtx_vae = nvtx.start_range(message="vae", color="red") + cudart.cudaEventRecord(self.events["vae-start"], 0) + images = self.backend.vae_decode(latents) + cudart.cudaEventRecord(self.events["vae-stop"], 0) + if self.nvtx_profile: + nvtx.end_range(nvtx_vae) + return images + + def print_summary(self, tic, toc, batch_size, vae_enc=False): + print("|------------|--------------|") + print("| {:^10} | {:^12} |".format("Module", "Latency")) + print("|------------|--------------|") + if vae_enc: + print( + "| {:^10} | {:>9.2f} ms |".format( + "VAE-Enc", + cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1], + ) + ) + print( + "| {:^10} | {:>9.2f} ms |".format( + "CLIP", cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1] + ) + ) + print( + "| {:^10} | {:>9.2f} ms |".format( + "UNet x " + str(self.actual_steps), + cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1], + ) + ) + print( + "| {:^10} | {:>9.2f} ms |".format( + "VAE-Dec", cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1] + ) + ) + + print("|------------|--------------|") + print("| {:^10} | {:>9.2f} ms |".format("Pipeline", (toc - tic) * 1000.0)) + print("|------------|--------------|") + print(f"Throughput: {batch_size / (toc - tic):.2f} image/s") + + @staticmethod + def to_pil_image(images): + images = ( + ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() + ) + + from PIL import Image + + return [Image.fromarray(images[i]) for i in range(images.shape[0])] + + def save_images(self, images, pipeline, prompt): + image_name_prefix = ( + pipeline + "".join(set(["-" + prompt[i].replace(" ", "_")[:10] for i in range(len(prompt))])) + "-" + ) + + images = self.to_pil_image(images) + random_session_id = str(random.randint(1000, 9999)) + for i, image in enumerate(images): + image_path = os.path.join( + self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + ".png" + ) + print(f"Saving image {i+1} / {len(images)} to: {image_path}") + image.save(image_path) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py new file mode 100644 index 0000000000..b9759b44e7 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py @@ -0,0 +1,155 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import time + +import torch +from diffusion_models import PipelineInfo +from pipeline_stable_diffusion import StableDiffusionPipeline + + +class Txt2ImgPipeline(StableDiffusionPipeline): + """ + Stable Diffusion Txt2Img pipeline using NVidia TensorRT. + """ + + def __init__(self, pipeline_info: PipelineInfo, **kwargs): + """ + Initializes the Txt2Img Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of stable diffusion pipeline. + """ + super().__init__(pipeline_info, **kwargs) + + def _infer( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=50, + guidance=7.5, + seed=None, + warmup=False, + return_type="latent", + ): + assert len(prompt) == len(negative_prompt) + batch_size = len(prompt) + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + with torch.inference_mode(), torch.autocast("cuda"): + # Pre-initialize latents + latents = self.initialize_latents( + batch_size=batch_size, + unet_channels=4, + latent_height=(image_height // 8), + latent_width=(image_width // 8), + ) + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # CLIP text encoder + text_embeddings = self.encode_prompt(prompt, negative_prompt) + + # UNet denoiser + latents = self.denoise_latent(latents, text_embeddings, guidance=guidance) + + # VAE decode latent + images = self.decode_latent(latents / self.vae_scaling_factor) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + if not warmup: + self.print_summary(e2e_tic, e2e_toc, batch_size) + self.save_images(images, "txt2img", prompt) + + return images, (e2e_toc - e2e_tic) * 1000.0 + + def run( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=7.5, + seed=None, + warmup=False, + return_type="image", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + warmup (bool): + Indicate if this is a warmup run. + return_type (str): + type of return. The value can be "latent" or "image". + """ + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py new file mode 100644 index 0000000000..1b3be143e6 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py @@ -0,0 +1,207 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import time + +import torch +from diffusion_models import PipelineInfo +from pipeline_stable_diffusion import StableDiffusionPipeline + + +class Txt2ImgXLPipeline(StableDiffusionPipeline): + """ + Stable Diffusion Txt2Img XL pipeline. + """ + + def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): + """ + Initializes the Txt2Img XL Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of stable diffusion pipeline. + """ + assert pipeline_info.is_xl_base() + + super().__init__(pipeline_info, *args, **kwargs) + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def _infer( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="image", + ): + assert len(prompt) == len(negative_prompt) + + original_size = (image_height, image_width) + crops_coords_top_left = (0, 0) + target_size = (image_height, image_width) + batch_size = len(prompt) + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + with torch.inference_mode(), torch.autocast("cuda"): + # Pre-initialize latents + latents = self.initialize_latents( + batch_size=batch_size, + unet_channels=4, + latent_height=(image_height // 8), + latent_width=(image_width // 8), + ) + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # CLIP text encoder + text_embeddings = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip", + tokenizer=self.tokenizer, + output_hidden_states=True, + force_zeros_for_empty_prompt=True, + ) + # CLIP text encoder 2 + text_embeddings2, pooled_embeddings2 = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip2", + tokenizer=self.tokenizer2, + pooled_outputs=True, + output_hidden_states=True, + force_zeros_for_empty_prompt=True, + ) + + # Merged text embeddings + text_embeddings = torch.cat([text_embeddings, text_embeddings2], dim=-1) + + # Time embeddings + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=text_embeddings.dtype + ) + add_time_ids = add_time_ids.repeat(batch_size, 1) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0).to(self.device) + + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + + # UNet denoiser + latents = self.denoise_latent( + latents, + text_embeddings, + denoiser="unetxl", + guidance=guidance, + add_kwargs=add_kwargs, + ) + + # VAE decode latent + if return_type == "latent": + images = latents + else: + images = self.decode_latent(latents / self.vae_scaling_factor) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + if not warmup: + print("SD-XL Base Pipeline") + self.print_summary(e2e_tic, e2e_toc, batch_size) + if return_type != "latent": + self.save_images(images, "txt2img-xl", prompt) + + return images, (e2e_toc - e2e_tic) * 1000.0 + + def run( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="image", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + warmup (bool): + Indicate if this is a warmup run. + return_type (str): + It can be "latent" or "image". + """ + + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt deleted file mode 100644 index b942749f8d..0000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt +++ /dev/null @@ -1,8 +0,0 @@ --r requirements.txt -onnxruntime-gpu>=1.14 -py3nvml>=0.2.7 -# cuda-python is needed for cuda graph. It shall be compatible with CUDA version of torch and onnxruntime-gpu. -cuda-python==11.7.0 -#To export onnx of stable diffusion, please install PyTorch 1.13.1+cu117 -#--extra-index-url https://download.pytorch.org/whl/cu117 -#torch==1.13.1+cu117 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt new file mode 100644 index 0000000000..5f908c4f5f --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt @@ -0,0 +1,21 @@ +-r requirements.txt + +# Official onnxruntime-gpu 1.16.1 is built with CUDA 11.8. +onnxruntime-gpu>=1.16.1 + +py3nvml + +# The version of cuda-python shall be compatible with installed CUDA version. +# For example, if your CUDA version is 12.1, you can install cuda-python 12.1. +cuda-python==11.8.0 + +# For windows, cuda-python need the following +pywin32; platform_system == "Windows" + +nvtx + +# Please install PyTorch 2.1 or above for CUDA 11.8 using one of the following commands: +# pip3 install torch --index-url https://download.pytorch.org/whl/cu118 + +# Run the following command to install some extra packages for onnx graph optimization for TensorRT manually. +# pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt new file mode 100644 index 0000000000..e4e765831c --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt @@ -0,0 +1,21 @@ +-r requirements.txt + +# For CUDA 12.*, you will need build onnxruntime-gpu from source and install the wheel. See README.md for detail. +# onnxruntime-gpu>=1.16.1 + +py3nvml + +# The version of cuda-python shall be compatible with installed CUDA version. +# For example, if your CUDA version is 12.1, you can install cuda-python 12.1. +cuda-python==12.1.0 + +# For windows, cuda-python need the following +pywin32; platform_system == "Windows" + +nvtx + +# Please install PyTorch 2.1 or above for 12.1 using one of the following commands: +# pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + +# Run the following command to install some extra packages for onnx graph optimization for TensorRT manually. +# pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt deleted file mode 100644 index 567f39c011..0000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt +++ /dev/null @@ -1,18 +0,0 @@ -diffusers>=0.16.0 -transformers>=4.26.0 -numpy>=1.24.1 -accelerate -onnx>=1.13.0 -coloredlogs -packaging -protobuf -psutil -sympy -tensorrt>=8.6.1 -onnxruntime-gpu>=1.15.1 -py3nvml -# cuda-python version shall be compatible with CUDA version of torch and onnxruntime-gpu -cuda-python==11.7.0 -#To export onnx of stable diffusion, please install PyTorch 1.13.1+cu117 -#--extra-index-url https://download.pytorch.org/whl/cu117 -#torch==1.13.1+cu117 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt index 68947a1618..9386a941fb 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt @@ -1,10 +1,15 @@ -diffusers>=0.15.1 -transformers>=4.26.0 +diffusers>=0.19.3 +transformers>=4.31.0 numpy>=1.24.1 accelerate onnx>=1.13.0 coloredlogs packaging +# Use newer version of protobuf might cause crash protobuf==3.20.3 psutil sympy +# The following are for SDXL +optimum>=1.11.1 +safetensors +invisible_watermark diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py new file mode 100644 index 0000000000..d03a9f9f55 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import tensorrt as trt + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + + +def init_trt_plugins(): + # Register TensorRT plugins + trt.init_libnvinfer_plugins(TRT_LOGGER, "") diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 5572900598..8ff5c8a6e1 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -2,18 +2,37 @@ ## Exporting Whisper with Beam Search -There are two ways to export Whisper with beam search (using Whisper tiny as an example). +There are several ways to export Whisper with beam search (using Whisper tiny as an example). + +### Option 1: from convert_to_onnx -Option 1: from source ``` +# From source $ git clone https://github.com/microsoft/onnxruntime -$ cd onnxruntime/onnxruntime/python/tools/transformers/models/whisper -$ python3 convert_to_onnx.py -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ cd onnxruntime/onnxruntime/python/tools/transformers/ +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format + +# From wheel +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format ``` -Option 2: from wheel +### Option 2: end-to-end model from [Olive](https://github.com/microsoft/Olive/tree/main/examples/whisper) + +Please follow the [README instructions](https://github.com/microsoft/Olive/tree/main/examples/whisper#prerequisites) in Olive. + +### Option 3: from [Hugging Face Optimum](https://github.com/huggingface/optimum) + +Run the following Python code to export: + ``` -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +from optimum.onnxruntime import ORTModelForSpeechSeq2Seq + +model_name = "openai/whisper-large-v2" +model = ORTModelForSpeechSeq2Seq.from_pretrained( + model_name, + export=True, +) +model.save_pretrained(model_name.split("/")[-1] + "-onnx") ``` ## Exporting + Optimizing + Quantizing Whisper with Beam Search @@ -23,7 +42,7 @@ Here are some additional examples for exporting Whisper with beam search. Export with Forced Decoder Input Ids ``` # From source: -$ python3 convert_to_onnx.py -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids # From wheel: $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids @@ -32,7 +51,7 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w Export + Optimize for FP32 ``` # From source: -$ python3 convert_to_onnx.py -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 # From wheel: $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 @@ -41,7 +60,7 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w Export + Optimize for FP16 and GPU ``` # From source: -$ python3 convert_to_onnx.py -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda # From wheel: $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda @@ -50,8 +69,128 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w Export + Quantize for INT8 ``` # From source: -$ python3 convert_to_onnx.py -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer # From wheel: $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer ``` + +## Benchmark Whisper + +Here are some examples of how you can benchmark Whisper across various end-to-end (E2E) implementations. + +### Variants + +1. PyTorch without `torch.compile`, FP32 +``` +python3 -m models.whisper.benchmark \ + --benchmark-type hf-pt-eager \ + --audio-path 1272-141231-0002.mp3 \ + --model-name openai/whisper-large-v2 \ + --precision fp32 \ + --device cpu +``` + +2. PyTorch with `torch.compile`, FP16 +``` +python3 -m models.whisper.benchmark \ + --benchmark-type hf-pt-compile \ + --audio-path 1272-141231-0002.mp3 \ + --model-name openai/whisper-large-v2 \ + --precision fp16 \ + --device cuda +``` + +3. Optimum + ONNX Runtime, FP32, export via Optimum +``` +python3 -m models.whisper.benchmark \ + --benchmark-type hf-ort \ + --audio-path 1272-141231-0002.mp3 \ + --model-name openai/whisper-large-v2 \ + --hf-ort-dir-path ./whisper-large-v2-onnx/ \ + --precision fp32 \ + --device cpu +``` + +4. ONNX Runtime, FP32, export via Olive or convert_to_onnx +``` +python3 -m models.whisper.benchmark \ + --benchmark-type ort \ + --audio-path 1272-141231-0002.mp3 \ + --model-name openai/whisper-large-v2 \ + --ort-model-path ./wlarge-fp32/whisper-large-v2_beamsearch.onnx \ + --precision fp32 \ + --device cpu +``` + +5. ONNX Runtime, FP16, export via Olive or convert_to_onnx +``` +python3 -m models.whisper.benchmark \ + --benchmark-type ort \ + --audio-path 1272-141231-0002.mp3 \ + --model-name openai/whisper-large-v2 \ + --ort-model-path ./wlarge-fp32/whisper-large_all.onnx \ + --precision fp16 \ + --device cuda +``` + +6. ONNX Runtime, INT8, export via Olive or convert_to_onnx +``` +python3 -m models.whisper.benchmark \ + --benchmark-type ort \ + --audio-path 1272-141231-0002.mp3 \ + --model-name openai/whisper-large-v2 \ + --ort-model-path ./wlarge-fp32/whisper-large-v2_all.onnx \ + --precision fp32 \ + --device cpu +``` + +You can profile a variant by adding the `--profile` flag. + +### Benchmark All + +You can use `benchmark_all.py` to benchmark across various platforms and automatically store the results in a CSV file. Here is an example. + +``` +python3 -m models.whisper.benchmark_all \ + --audio-path ./whisper-test-audios/ \ + --hf-pt-eager \ + --hf-pt-compile \ + --hf-ort-dir-path ./whisper-large-v2-onnx/ \ + --ort-model-path ./wlarge-fp32/whisper-large-v2_all.onnx \ + --model-name openai/whisper-large-v2 \ + --precision fp32 \ + --device cpu +``` + +### Benchmarking on NVIDIA A100 + +Here is a benchmark for an MP3 file with 20.7s of audio. + +#### FP16 + +| Engine | Size | Per-Token Latency | Real-Time Factor | +| --------------- | -------- | ----------------- | ---------------- | +| PyTorch eager | Tiny | 4.697 ms/token | 0.004697 | +| PyTorch compile | Tiny | 3.406 ms/token | 0.003406 | +| ONNX Runtime | Tiny | 0.746 ms/token | 0.000746 | +| PyTorch eager | Medium | 17.837 ms/token | 0.017387 | +| PyTorch compile | Medium | 18.124 ms/token | 0.018124 | +| ONNX Runtime | Medium | 3.894 ms/token | 0.003894 | +| PyTorch eager | Large v2 | 23.470 ms/token | 0.023470 | +| PyTorch compile | Large v2 | 23.146 ms/token | 0.023146 | +| ONNX Runtime | Large v2 | 6.262 ms/token | 0.006262 | + +#### FP32 + +| Engine | Size | Per-Token Latency | Real-Time Factor | +| --------------- | -------- | ----------------- | ---------------- | +| PyTorch eager | Tiny | 6.220 ms/token | 0.006220 | +| PyTorch compile | Tiny | 3.944 ms/token | 0.003944 | +| ONNX Runtime | Tiny | 1.545 ms/token | 0.001545 | +| PyTorch eager | Medium | 19.093 ms/token | 0.019093 | +| PyTorch compile | Medium | 20.459 ms/token | 0.020459 | +| ONNX Runtime | Medium | 9.440 ms/token | 0.009440 | +| PyTorch eager | Large v2 | 25.844 ms/token | 0.025844 | +| PyTorch compile | Large v2 | 26.397 ms/token | 0.026397 | +| ONNX Runtime | Large v2 | 7.492 ms/token | 0.007492 | diff --git a/onnxruntime/python/tools/transformers/models/whisper/__init__.py b/onnxruntime/python/tools/transformers/models/whisper/__init__.py index 815be385d7..e80f36a391 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/__init__.py +++ b/onnxruntime/python/tools/transformers/models/whisper/__init__.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -import os.path +import os import sys sys.path.append(os.path.dirname(__file__)) diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py new file mode 100644 index 0000000000..759ae6d14f --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -0,0 +1,593 @@ +import argparse +import ast +import datetime +import gc +import logging +import os +import sys +import time + +import numpy as np +import psutil +import torch +import whisper +from benchmark_helper import measure_memory, setup_logger +from onnxruntime_extensions import get_library_path +from optimum.onnxruntime import ORTModelForSpeechSeq2Seq +from torch.profiler import ProfilerActivity, profile, record_function +from tqdm import trange +from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor + +import onnxruntime as ort + +logger = logging.getLogger(__name__) + + +def get_inputs(args: argparse.Namespace): + if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"}: + raise Exception("Unable to auto-detect inputs for provided model") + + def load_via_ffmpeg(): + audio = whisper.load_audio(args.audio_path) + audio = whisper.pad_or_trim(audio) + return audio + + def load_via_numpy(): + with open(args.audio_path, "rb") as f: + audio = np.asarray(list(f.read()), dtype=np.uint8) + audio = np.array([audio]) + return audio + + inputs = { + "max_length": args.max_length, + "min_length": args.min_length, + "num_beams": args.num_beams, + "num_return_sequences": args.num_return_sequences, + "length_penalty": args.length_penalty, + "repetition_penalty": args.repetition_penalty, + } + if args.benchmark_type == "ort": + # convert_to_onnx export or ONNX E2E solution created by Olive + for k, v in inputs.items(): + inputs[k] = np.array([v], dtype=np.float32 if "penalty" in k else np.int32) + if args.has_decoder_input_ids: + inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32) + if args.has_logits_processor: + inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32) + + # Measure time taken to load audio file + logger.info(f"Load audio: {args.audio_path}") + load_audio_fn = lambda onnx_e2e: load_via_numpy() if onnx_e2e else load_via_ffmpeg() # noqa: E731 + time_fn(args, load_audio_fn, args.has_audio_stream) + audio_data = load_audio_fn(args.has_audio_stream) + + if args.has_audio_stream: + # ONNX E2E solution created by Olive + inputs["audio_stream"] = audio_data + return inputs + + # Measure time taken to get input features + logger.info("Feature extraction: ") + return_type = "np" if args.benchmark_type == "ort" else "pt" + processor_fn = lambda audio: args.processor.feature_extractor( # noqa: E731 + [audio], return_tensors=return_type, sampling_rate=args.sampling_rate + ).input_features + time_fn(args, processor_fn, audio_data) + input_features = processor_fn(audio_data) + + if args.benchmark_type == "ort": + # convert_to_onnx export + inputs["input_features"] = input_features + return inputs + + inputs["inputs"] = input_features.to( + dtype=torch.float16 if args.use_fp16 else torch.float32, device=args.target_device + ) + inputs["no_repeat_ngram_size"] = args.no_repeat_ngram_size + inputs["early_stopping"] = True + inputs["use_cache"] = True + + if args.decoder_input_ids: + inputs["forced_decoder_ids"] = args.decoder_input_ids + + return inputs + + +def get_model(args: argparse.Namespace): + model, sess_options = None, None + start_time, end_time = None, None + + # There are multiple sources that the model could come from: + # 1) Benchmark Whisper from Hugging Face + # 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing) + # 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing) + + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: + source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name + start_time = time.time() + model = AutoModelForSpeechSeq2Seq.from_pretrained( + source, + torch_dtype=torch.float16 if args.use_fp16 else torch.float32, + use_cache=True, + ).to(args.target_device) + end_time = time.time() + + if args.benchmark_type == "hf-pt-compile": + model = torch.compile(model) + + elif args.benchmark_type in {"hf-ort", "ort"}: + sess_options = ort.SessionOptions() + sess_options.enable_profiling = args.profile + sess_options.register_custom_ops_library(get_library_path()) + if args.verbose: + sess_options.log_verbosity_level = 1 + sess_options.log_severity_level = 1 + if args.tune: + ort.set_default_logger_severity(0) + ort.set_default_logger_verbosity(0) + + else: + raise Exception(f"Cannot recognize {args.benchmark_type}") + + if args.benchmark_type == "hf-ort": + # Optimum export + provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider + provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None + + start_time = time.time() + model = ORTModelForSpeechSeq2Seq.from_pretrained( + args.hf_ort_dir_path, + use_io_binding=(args.device != "cpu"), + provider=provider, + provider_options=provider_options, + session_options=sess_options, + ) + end_time = time.time() + + if args.benchmark_type == "ort": + # convert_to_onnx.py export + logger.info(f"Loading model from {args.ort_model_path}") + start_time = time.time() + model = ort.InferenceSession( + args.ort_model_path, + sess_options, + providers=[args.execution_provider], + ) + end_time = time.time() + + logger.info(f"Loaded model in {end_time - start_time} s") + + return model + + +def time_fn(args, fn, inputs): + warmup_inputs = inputs[0] if type(inputs) is tuple else inputs + benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs + + # Warm up + warmup_range = ( + range(args.warmup_runs) + if args.benchmark_type == "ort" + else trange(args.warmup_runs, file=sys.stdout, desc="Warm up") + ) + + if args.verbose: + outputs = fn(warmup_inputs) + logger.info(outputs) + + for _ in warmup_range: + fn(warmup_inputs) + + # Benchmark + if args.device != "cpu": + torch.cuda.synchronize() + start_time = time.time() + + bench_range = ( + range(args.num_runs) + if args.benchmark_type == "ort" + else trange(args.num_runs, file=sys.stdout, desc="Benchmark") + ) + for _ in bench_range: + fn(benchmark_inputs) + + if args.device != "cpu": + torch.cuda.synchronize() + end_time = time.time() + + # Newline print after trange in order to print metrics on new lines without progress bar on same line + if args.benchmark_type != "ort": + logger.info("") + + batch_size = 1 + latency = (end_time - start_time) / args.num_runs + throughput = batch_size / latency + + logger.info(f"Latency: {latency} s") + logger.info(f"Throughput: {throughput} qps") + return + + +def profile_fn(args, fn, inputs, inputs_type): + # Filename prefix format: + # "--___" + prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}" + filename = None + + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: + # Profile PyTorch kernels + with profile( # noqa: SIM117 + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True + ) as prof: + with record_function("model_inference"): + fn(inputs) + prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows) + + filename = os.path.join(args.log_folder, f"{prefix}.log") + with open(filename, "w") as f: + f.write(prof_data) + + else: + # Profile ORT kernels + fn(inputs) + + # Set new log name for ORT profile log generated + filename = f"{prefix}.json" + + return filename + + +def measure_fn(args, fn, inputs): + # Measure CPU usage + pid = os.getpid() + process = psutil.Process(pid) + process.cpu_percent(interval=0.1) + + fn(inputs) + logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%") + + # Measure memory usage + gc.collect() + torch.cuda.empty_cache() + measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type) + + # Flush output so memory usage is printed + sys.stdout.flush() + + +def run_hf_inference(args, inputs, model): + # Inference steps to measure + def get_pred_ids(inputs): + # Inference pass with predicted token ids generation + predicted_ids = model.generate(**inputs) + return predicted_ids + + def gen_and_dec(inputs): + # Inference pass with generation and decoding + predicted_ids = get_pred_ids(inputs) + transcription = [] + for _ in range(args.num_return_sequences): + transcription.append(args.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]) + return predicted_ids, transcription + + # Examples of other inference steps that can be measured: + # To use, uncomment the function and assign it to `generate_fn` + + # def get_logits(inputs): + # # Inference pass without decoding + # outputs = model(**inputs) + # return outputs + + generate_fn = gen_and_dec + + if args.benchmark_type == "hf-pt-compile": + # Run forward pass once with each set of inputs to process through Dynamo + generate_fn(inputs) + + if args.profile: + new_logname = profile_fn(args, generate_fn, inputs, "gen-and-dec") + if args.benchmark_type == "hf-ort": + # Rename log files per model component and turn profiling off to stop appending to log + new_prefix = new_logname[: -len(".json")] + + old_logname = model.encoder.session.end_profiling() + new_logname = new_prefix + "-encoder.json" + if os.path.isfile(old_logname): + logger.warning(f"Renaming {old_logname} to {new_logname}") + os.rename(old_logname, os.path.join(args.log_folder, new_logname)) + + old_logname = model.decoder.session.end_profiling() + new_logname = new_prefix + "-decoder.json" + if os.path.isfile(old_logname): + logger.warning(f"Renaming {old_logname} to {new_logname}") + os.rename(old_logname, os.path.join(args.log_folder, new_logname)) + + old_logname = model.decoder_with_past.session.end_profiling() + new_logname = new_prefix + "-decoder-with-past.json" + if os.path.isfile(old_logname): + logger.warning(f"Renaming {old_logname} to {new_logname}") + os.rename(old_logname, os.path.join(args.log_folder, new_logname)) + + return + + # PyTorch evaluations + logger.info("\nEvaluating PyTorch...") + time_fn(args, generate_fn, inputs) + predicted_ids, transcription = generate_fn(inputs) + logger.info(f"Generated token length: {len(predicted_ids[0])} tokens") + logger.info(f"Transcription: {transcription[0]}") + measure_fn(args, generate_fn, inputs) + + +def run_ort_inference(args, inputs, model): + def prepare_ort_inputs(inputs, warmup=False): + # Check that all model inputs will be provided + model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs())) + user_inputs = set(inputs.keys()) + missing_inputs = model_inputs - user_inputs + if len(missing_inputs): + logger.error(f"The following model inputs are missing: {missing_inputs}") + raise Exception("There are missing inputs to the model. Please add them and try again.") + + if warmup and args.tune: + inputs["min_length"] = inputs["max_length"] + + # Remove unnecessary inputs from model inputs + unnecessary_inputs = user_inputs - model_inputs + if len(unnecessary_inputs): + for unnecessary_input in unnecessary_inputs: + logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs") + del inputs[unnecessary_input] + + # Add IO bindings for non-CPU execution providers + if args.device != "cpu": + io_binding = model.io_binding() + for k, v in inputs.items(): + io_binding.bind_cpu_input(k, v) + for output in model.get_outputs(): + io_binding.bind_output(output.name, device_type=args.device, device_id=args.device_id) + return io_binding + + return inputs + + def with_io_binding(io_binding): + # Inference pass with IO binding + model.run_with_iobinding(io_binding) + return io_binding + + def without_io_binding(inputs): + # Inference pass without IO binding + outputs = model.run(None, inputs) + return outputs + + def handle_output(output): + if args.eos_token_id in output: + first_end = np.where(output == args.eos_token_id)[0][0] + return output[: first_end + 1] + + return output + + generate_fn = with_io_binding if args.device != "cpu" else without_io_binding + ort_inputs = prepare_ort_inputs(inputs) + + if args.profile: + new_logname = profile_fn(args, generate_fn, ort_inputs, "e2e") + + # Turn profiling off to stop appending to log file + old_logname = model.end_profiling() + logger.warning(f"Renaming {old_logname} to {new_logname}") + os.rename(old_logname, os.path.join(args.log_folder, new_logname)) + + return + + # ORT evaluation + logger.info("\nEvaluating ONNX Runtime...") + ort_evaluate_inputs = ort_inputs + if args.tune: + ort_warmup_inputs = prepare_ort_inputs(inputs, warmup=True) + ort_evaluate_inputs = (ort_warmup_inputs, ort_inputs) + + time_fn(args, generate_fn, ort_evaluate_inputs) + ort_outputs = generate_fn(ort_inputs) + if args.device != "cpu": + ort_outputs = ort_outputs.copy_outputs_to_cpu() + ort_outputs = ort_outputs[0] + + if args.has_audio_stream: + # ONNX E2E model from Olive produces transcribed output + logger.info(f"Transcription: {ort_outputs[0][0]}") + else: + # convert_to_onnx model produces generated ids + actual_output = handle_output(ort_outputs[0][0]) + logger.info(f"Generated token length: {len(actual_output)} tokens") + transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0] + logger.info(f"Transcription: {transcription}") + + measure_fn(args, generate_fn, ort_inputs) + + +def run_inference(args, inputs, model): + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}: + run_hf_inference(args, inputs, model) + elif args.benchmark_type == "ort": + run_ort_inference(args, inputs, model) + else: + raise Exception(f"Cannot recognize {args.benchmark_type}") + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-bt", + "--benchmark-type", + type=str, + required=True, + choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"], + ) + + parser.add_argument( + "-m", + "--model-name", + type=str, + required=True, + help="Hugging Face name of model (e.g. 'openai/whisper-large-v2')", + ) + parser.add_argument( + "-p", + "--precision", + type=str, + required=True, + default="fp32", + choices=["int8", "fp16", "fp32"], + help="Precision for model. For ONNX models, the model's precision should be set before running this script.", + ) + + parser.add_argument( + "--hf-pt-model-path", + type=str, + default="", + help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)", + ) + parser.add_argument( + "--hf-ort-dir-path", + type=str, + default="", + help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)", + ) + parser.add_argument( + "--ort-model-path", + type=str, + default="", + help="Path to ONNX model", + ) + + # Args for running and evaluating the model + parser.add_argument("-a", "--audio-path", type=str, required=True, help="Path to audio file for E2E evaluation") + parser.add_argument( + "-d", + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + choices=["cpu", "cuda", "rocm"], + ) + parser.add_argument("-id", "--device-id", type=int, default=0) + parser.add_argument("-w", "--warmup-runs", type=int, default=5) + parser.add_argument("-n", "--num-runs", type=int, default=10) + parser.add_argument("--seed", type=int, default=2) + + # Optional args: + parser.add_argument("--sampling-rate", type=int, default=16000, help="Sampling rate for audio (in Hz)") + + # Args for decoding logic + # Required args: + parser.add_argument("--max-length", type=int, default=448) + parser.add_argument("--min-length", type=int, default=0) + parser.add_argument("--num-beams", type=int, default=1) + parser.add_argument("--num-return-sequences", type=int, default=1) + parser.add_argument("--length-penalty", type=float, default=1.0) + parser.add_argument("--repetition-penalty", type=float, default=1.0) + parser.add_argument("--no-repeat-ngram-size", type=int, default=3) + + # Optional args for E2E solution: + parser.add_argument( + "--decoder-input-ids", + type=str, + default="[]", + help="The forced decoder ids for generation. Format is [start token, timestamp token, language token, task token]. Default is [start token]. See `decoder_input_ids` in https://github.com/microsoft/Olive/tree/main/examples/whisper for details.", + ) + parser.add_argument( + "--logits-processor", + type=int, + default=1, + help="Type of logits processor to use. See `BeamSearch` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/contrib_ops/contrib_defs.cc for details.", + ) + + # Args for accessing detailed info + parser.add_argument("--profile", default=False, action="store_true") + parser.add_argument( + "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by" + ) + parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display") + parser.add_argument("--verbose", default=False, action="store_true") + parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files") + parser.add_argument( + "--tune", + default=False, + action="store_true", + help="Only used by ROCm EP, enable TunableOp tuning to select fastest kernel", + ) + + args = parser.parse_args() + + # Set seed properties + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + args.monitor_type = args.device + # Set runtime properties + if "ort" in args.benchmark_type: + args.execution_provider = f"{args.device.upper()}ExecutionProvider" + if args.execution_provider == "CUDAExecutionProvider": + args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) + elif args.execution_provider == "ROCMExecutionProvider": + args.execution_provider = ( + args.execution_provider, + { + "device_id": args.device_id, + "tunable_op_enable": 1, + "tunable_op_tuning_enable": 1 if args.tune else 0, + }, + ) + args.device = "cuda" + + # Check that model paths have been specified for any benchmarking with ORT + if args.benchmark_type == "hf-ort": + assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`" + if args.benchmark_type == "ort": + assert args.ort_model_path, "Please specify a path to `--ort-model-path`" + + # Convert decoder_input_ids string to list of ids + # (e.g. "[1, 50257]" for Hugging Face or "[50257]" for ORT) + args.decoder_input_ids = ast.literal_eval(args.decoder_input_ids) + + return args + + +def main(): + args = parse_args() + setup_logger(args.verbose) + logger.info(args.__dict__) + torch.backends.cudnn.benchmark = True + + config = WhisperConfig.from_pretrained(args.model_name) + processor = WhisperProcessor.from_pretrained(args.model_name) + target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device + use_fp16 = args.precision == "fp16" + + setattr(args, "processor", processor) # noqa: B010 + setattr(args, "target_device", target_device) # noqa: B010 + setattr(args, "use_fp16", use_fp16) # noqa: B010 + setattr(args, "has_audio_stream", False) # noqa: B010 + setattr(args, "eos_token_id", config.eos_token_id) # noqa: B010 + + logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}") + + # Measure cost to transcribe audio + model = get_model(args) + if args.benchmark_type == "ort": + # Check for optional inputs that could have been added during export + ort_model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs())) + args.has_audio_stream = "audio_stream" in ort_model_inputs + setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs) # noqa: B010 + setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs) # noqa: B010 + + if args.decoder_input_ids == []: + args.decoder_input_ids = [config.decoder_start_token_id] + + inputs = get_inputs(args) + run_inference(args, inputs, model) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py new file mode 100644 index 0000000000..071b539ac1 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -0,0 +1,455 @@ +import argparse +import datetime +import json +import logging +import os +import subprocess + +import librosa +import torch +from benchmark_helper import setup_logger +from transformers import WhisperConfig, WhisperProcessor + +logger = logging.getLogger(__name__) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-a", + "--audio-path", + type=str, + required=True, + help="Path to folder of audio files for E2E evaluation", + ) + + parser.add_argument( + "-l", + "--language", + default=None, + help="Language of audio file", + ) + + parser.add_argument( + "-t", + "--task", + default=None, + choices=["transcribe", "translate"], + help="Task to complete", + ) + + parser.add_argument( + "-w", + "--warmup-runs", + type=int, + default=5, + ) + + parser.add_argument( + "-n", + "--num-runs", + type=int, + default=10, + ) + + parser.add_argument( + "--hf-pt-eager", + default=False, + action="store_true", + help="Benchmark in PyTorch without `torch.compile`", + ) + + parser.add_argument( + "--hf-pt-compile", + default=False, + action="store_true", + help="Benchmark in PyTorch with `torch.compile`", + ) + + parser.add_argument( + "--hf-ort-dir-path", + type=str, + help="Path to folder containing ONNX models for Optimum + ORT benchmarking", + ) + + parser.add_argument( + "--ort-model-path", + type=str, + help="Path to ONNX model for ORT benchmarking", + ) + + parser.add_argument( + "--model-name", + type=str, + required=True, + help="Model name in Hugging Face (e.g. openai/whisper-large-v2)", + ) + + parser.add_argument( + "--precision", + type=str, + required=True, + choices=["int8", "fp16", "fp32"], + help="Precision to run model", + ) + + parser.add_argument( + "--device", + type=str, + required=True, + choices=["cpu", "cuda", "rocm"], + help="Device to benchmark models", + ) + + parser.add_argument( + "--device-id", + type=int, + default=0, + help="GPU device ID", + ) + + parser.add_argument( + "--verbose", + default=False, + action="store_true", + help="Print detailed logs", + ) + + parser.add_argument( + "--timeout", + type=int, + default=5, + help="Number of mins to attempt the benchmark before moving on", + ) + + parser.add_argument("--tune", default=False, action="store_true") + + args = parser.parse_args() + + setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010 + log_folder_name = f"./{args.model_size}-{args.precision}" + setattr(args, "log_folder", log_folder_name) # noqa: B010 + os.makedirs(args.log_folder, exist_ok=True) + + # Convert timeout value to secs + args.timeout *= 60 + + return args + + +def process_log_file(device_id, log_file, base_results): + entries = [] + + # Detect steps in speech pipeline + step = None + load_audio_pattern = "Load audio: " + feat_ext_pattern = "Feature extraction: " + pytorch_pattern = "Evaluating PyTorch..." + onnxruntime_pattern = "Evaluating ONNX Runtime..." + + load_audio_latency_s, load_audio_throughput_s = None, None + feat_ext_latency_s, feat_ext_throughput_s = None, None + token_length, latency_s, per_token_latency_s, per_token_latency_ms = None, None, None, None + throughput, memory = None, None + + # Detect metrics + latency_pattern = "Latency: " + throughput_pattern = "Throughput: " + token_length_pattern = "Generated token length: " + memory_pattern = "peak=" + + with open(log_file) as f: + for input_line in f: + line = input_line.replace("\n", "") + + # Get step in speech recognition pipeline + if load_audio_pattern in line: + step = "load-audio" + elif feat_ext_pattern in line: + step = "feature-extraction" + elif pytorch_pattern in line or onnxruntime_pattern in line: + step = "process" + + # Check metrics + if latency_pattern in line: + latency_s = float(line[len(latency_pattern) : line.rfind(" ")]) + elif throughput_pattern in line: + throughput = float(line[len(throughput_pattern) : line.rfind(" ")]) + if step == "load-audio": + load_audio_latency_s, load_audio_throughput_s = latency_s, throughput + step = None + if step == "feature-extraction": + feat_ext_latency_s, feat_ext_throughput_s = latency_s, throughput + step = None + elif token_length_pattern in line: + token_length = int(line[len(token_length_pattern) : line.rfind(" ")]) + per_token_latency_s = latency_s / token_length + per_token_latency_ms = per_token_latency_s * 1000 + elif memory_pattern in line: + if "CPU" in line: + # Example format for log entry: + # CPU memory usage: before=1000.0 MB, peak=2000.0 MB + memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000 + else: + # Example format for log entry: + # GPU memory usage: before=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1638.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}, peak=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1780.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}] + peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"') + usage = json.loads(peak)[device_id]["max_used_MB"] + memory = float(usage) / 1000 + + # Calculate real-time factor (RTF): + # RTF = total latency / audio duration + total_latency = ( + (load_audio_latency_s if load_audio_latency_s else 0) + + (feat_ext_latency_s if feat_ext_latency_s else 0) + + (latency_s if latency_s else 0) + ) + audio_duration = base_results[-1] + rtf = (total_latency / audio_duration) if audio_duration else -1 + logger.info(f"Total latency: {total_latency} s") + logger.info(f"Audio duration: {audio_duration} s") + logger.info(f"Real-time factor: {rtf}") + + # Append log entry to list of entries + entry = base_results + [ # noqa: RUF005 + token_length, + load_audio_latency_s, + load_audio_throughput_s, + feat_ext_latency_s if feat_ext_latency_s else -1, + feat_ext_throughput_s if feat_ext_throughput_s else -1, + latency_s, + per_token_latency_ms, + throughput, + memory, + rtf, + ] + entries.append(entry) + + return entries + + +def save_results(results, filename): + import pandas as pd + + df = pd.DataFrame( + results, + columns=[ + "Engine", + "Precision", + "Device", + "Audio File", + "Duration (s)", + "Token Length", + "Load Audio Latency (s)", + "Load Audio Throughput (qps)", + "Feature Extractor Latency (s)", + "Feature Extractor Throughput (qps)", + "Latency (s)", + "Per Token Latency (ms/token)", + "Throughput (qps)", + "Memory (GB)", + "Real Time Factor (RTF)", + ], + ) + + # Set column types + df["Duration (s)"] = df["Duration (s)"].astype("float") + df["Token Length"] = df["Token Length"].astype("int") + df["Load Audio Latency (s)"] = df["Load Audio Latency (s)"].astype("float") + df["Load Audio Throughput (qps)"] = df["Load Audio Throughput (qps)"].astype("float") + df["Feature Extractor Latency (s)"] = df["Feature Extractor Latency (s)"].astype("float") + df["Feature Extractor Throughput (qps)"] = df["Feature Extractor Throughput (qps)"].astype("float") + df["Latency (s)"] = df["Latency (s)"].astype("float") + df["Per Token Latency (ms/token)"] = df["Per Token Latency (ms/token)"].astype("float") + df["Throughput (qps)"] = df["Throughput (qps)"].astype("float") + df["Memory (GB)"] = df["Memory (GB)"].astype("float") + df["Real Time Factor (RTF)"] = df["Real Time Factor (RTF)"].astype("float") + + df.to_csv(filename, index=False) + logger.info(f"Results saved in {filename}!") + + +def benchmark(args, benchmark_cmd, engine, audio_file, duration): + log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log" + log_path = os.path.join(args.log_folder, log_filename) + with open(log_path, "w") as log_file: + process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file) + try: + process.wait(args.timeout) + except subprocess.TimeoutExpired: + process.kill() + + # Create entries for csv + logger.info("Gathering data from log files...") + base_results = [engine, args.precision, args.device, audio_file, duration] + results = process_log_file(args.device_id, log_path, base_results) + + return results + + +def main(): + args = get_args() + setup_logger(args.verbose) + logger.info(args.__dict__) + torch.backends.cudnn.benchmark = True + + config = WhisperConfig.from_pretrained(args.model_name) + processor = WhisperProcessor.from_pretrained(args.model_name) + + # Calculate forced decoder input ids + hf_forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.language, task=args.task) + ort_forced_decoder_ids = [config.decoder_start_token_id] + list( # noqa: RUF005 + map(lambda token_id: token_id[1], hf_forced_decoder_ids) + ) + hf_decoder_input_ids_cmd = ( + ["--decoder-input-ids", str(hf_forced_decoder_ids)] if args.language and args.task else [] + ) + ort_decoder_input_ids_cmd = ( + ["--decoder-input-ids", str(ort_forced_decoder_ids)] if args.language and args.task else [] + ) + ort_tune_cmd = ["--tune"] if args.tune else [] + + all_results = [] + for audio_file in os.listdir(args.audio_path): + audio_path = os.path.join(args.audio_path, audio_file) + try: + duration = librosa.get_duration(path=audio_path) + except Exception as e: + duration = -1 + logger.warning(f"An error occurred while trying to calculate the audio duration: {e}", exc_info=True) + logger.warning( + f"If you get an error that says:\n\tsoundfile.LibsndfileError: Error opening '{audio_file}': File contains data in an unknown format.\nyou may not have installed `ffmpeg` in addition to installing `librosa`." + ) + logger.info(f"Testing {audio_path}...") + + # Benchmark PyTorch without torch.compile + if args.hf_pt_eager: + benchmark_cmd = [ # noqa: RUF005 + "python", + "-m", + "models.whisper.benchmark", + "--audio-path", + audio_path, + "--benchmark-type", + "hf-pt-eager", + "--model-name", + args.model_name, + "--precision", + args.precision, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + hf_decoder_input_ids_cmd + logger.info("Benchmark PyTorch without torch.compile") + results = benchmark(args, benchmark_cmd, "pytorch-eager", audio_file, duration) + all_results.extend(results) + + # Benchmark PyTorch with torch.compile + if args.hf_pt_compile: + benchmark_cmd = [ # noqa: RUF005 + "python", + "-m", + "models.whisper.benchmark", + "--audio-path", + audio_path, + "--benchmark-type", + "hf-pt-compile", + "--model-name", + args.model_name, + "--precision", + args.precision, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + hf_decoder_input_ids_cmd + logger.info("Benchmark PyTorch with torch.compile") + results = benchmark(args, benchmark_cmd, "pytorch-compile", audio_file, duration) + all_results.extend(results) + + # Benchmark Optimum + ONNX Runtime + if args.hf_ort_dir_path: + benchmark_cmd = [ # noqa: RUF005 + "python", + "-m", + "models.whisper.benchmark", + "--audio-path", + audio_path, + "--benchmark-type", + "hf-ort", + "--hf-ort-dir-path", + args.hf_ort_dir_path, + "--model-name", + args.model_name, + "--precision", + args.precision, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + hf_decoder_input_ids_cmd + logger.info("Benchmark Optimum + ONNX Runtime") + results = benchmark(args, benchmark_cmd, "optimum-ort", audio_file, duration) + all_results.extend(results) + + # Benchmark ONNX Runtime + if args.ort_model_path: + benchmark_cmd = ( + [ # noqa: RUF005 + "python", + "-m", + "models.whisper.benchmark", + "--audio-path", + audio_path, + "--benchmark-type", + "ort", + "--ort-model-path", + args.ort_model_path, + "--model-name", + args.model_name, + "--precision", + args.precision, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + + ort_decoder_input_ids_cmd + + ort_tune_cmd + ) + logger.info("Benchmark ONNX Runtime") + results = benchmark(args, benchmark_cmd, "onnxruntime", audio_file, duration) + all_results.extend(results) + + csv_file = f"{args.model_size}-{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv" + save_results(all_results, os.path.join(args.log_folder, csv_file)) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 2821f6b89b..3562df1660 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -180,22 +180,25 @@ def parse_arguments(argv=None): "--quantize_embedding_layer", required=False, action="store_true", - help="Produce beam search model with chained encdecinit and decoder.", + help="Quantize MatMul, GEMM, and Gather.", ) + parser.set_defaults(quantize_embedding_layer=False) parser.add_argument( "--quantize_per_channel", required=False, action="store_true", - help="Produce beam search model with chained encdecinit and decoder.", + help="Quantize weights per each channel.", ) + parser.set_defaults(quantize_per_channel=False) parser.add_argument( "--quantize_reduce_range", required=False, action="store_true", - help="Produce beam search model with chained encdecinit and decoder.", + help="Quantize weights with 7 bits.", ) + parser.set_defaults(quantize_reduce_range=False) parser.add_argument("--no_repeat_ngram_size", type=int, default=0, help="default to 0") diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 7e2325c148..3b1e656136 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -135,7 +135,7 @@ def chain_model(args): # Initializers/opsets # Delete shared data between decoder/encoder and move to larger graph initializers - initializers = get_shared_initializers(encoder_model, decoder_model, require_raw_data=True) + initializers = get_shared_initializers(encoder_model, decoder_model) node.attribute.extend( [ helper.make_attribute("decoder", decoder_model.graph), diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 4f74da577d..5fda3e6d84 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -23,6 +23,7 @@ numpy_helper, save_model, ) +from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data from shape_infer_helper import SymbolicShapeInferenceHelper logger = logging.getLogger(__name__) @@ -336,6 +337,18 @@ def match_parent_paths(self, node, paths, output_name_to_node): return i, matched, return_indice return -1, None, None + def match_parent_paths_all(self, node, paths, output_name_to_node): + match_i, matches, return_indices = [], [], [] + for i, path in enumerate(paths): + assert isinstance(path, (List, Tuple)) + return_indice = [] + matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice) + if matched: + match_i.append(i) + matches.append(matched) + return_indices.append(return_indice) + return match_i, matches, return_indices + def match_parent_path( self, node, @@ -609,7 +622,7 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): When symbolic shape inference is used (even if it failed), ONNX shape inference will be disabled. - Note that onnx shape inference will fail for model larger than 2GB. For large model, you have to eanble + Note that onnx shape inference will fail for model larger than 2GB. For large model, you have to enable symbolic shape inference. If your model is not optimized, you can also use model path to call convert_float_to_float16 in float16.py (see https://github.com/microsoft/onnxruntime/pull/15067) to avoid the 2GB limit. @@ -815,51 +828,77 @@ def prune_graph(self, outputs=None, allow_remove_graph_inputs=True): """ if len(self.graphs()) > 1: + # TODO(tianleiwu): handle subgraph logger.debug("Skip prune_graph since graph has subgraph") return - if outputs is None: - outputs = [output.name for output in self.model.graph.output] + keep_outputs = [output.name for output in self.model.graph.output] if outputs is None else outputs output_name_to_node = self.output_name_to_node() - all_nodes = [] - for output in outputs: - if output in output_name_to_node: - last_node = output_name_to_node[output] - if last_node in all_nodes: - continue - nodes = self.get_parent_subgraph_nodes(last_node, []) - all_nodes.append(last_node) - all_nodes.extend(nodes) - nodes_to_remove = [node for node in self.model.graph.node if node not in all_nodes] + def get_first_output(node): + if node.output[0]: + return node.output[0] + return next(iter([o for o in node.output if o]), None) - self.remove_nodes(nodes_to_remove) + # Keep track of nodes to keep. The key is first output of node, and the value is the node. + output_to_node = {} - # remove outputs not in list - output_to_remove = [] - for output in self.model.graph.output: - if output.name not in outputs: - output_to_remove.append(output) - for output in output_to_remove: - self.model.graph.output.remove(output) + # Start from graph outputs, and find parent nodes recursively, and add nodes to the output_to_node dictionary. + dq = deque() + for output in keep_outputs: + if output in output_name_to_node: + dq.append(output_name_to_node[output]) + while len(dq) > 0: + node = dq.pop() + first_output = get_first_output(node) + if first_output and (first_output not in output_to_node): + output_to_node[first_output] = node + for name in node.input: + if len(name) > 0 and (name in output_name_to_node) and (name not in output_to_node): + dq.appendleft(output_name_to_node[name]) + + # Keep only those nodes in the output_to_node dictionary. + nodes_to_keep = [] + num_nodes_removed = 0 + for node in self.model.graph.node: + first_output = get_first_output(node) + kept_node = output_to_node[first_output] if first_output in output_to_node else None + + # Need double check the node since fused node might reuse output name of some nodes to be removed. + # It is slow to compare whole node, so we compare op_type first to avoid comparing node in most cases. + if kept_node and kept_node.op_type == node.op_type and kept_node == node: + nodes_to_keep.append(node) + else: + num_nodes_removed += 1 + self.model.graph.ClearField("node") + self.model.graph.node.extend(nodes_to_keep) - # remove inputs not used by any node. + # Remove graph outputs not in list + output_to_remove = [] + if outputs is not None: + for output in self.model.graph.output: + if output.name not in outputs: + output_to_remove.append(output) + for output in output_to_remove: + self.model.graph.output.remove(output) + + # Remove graph inputs not used by any node. input_to_remove = [] if allow_remove_graph_inputs: input_name_to_nodes = self.input_name_to_nodes() input_to_remove = [input for input in self.model.graph.input if input.name not in input_name_to_nodes] - for input in input_to_remove: - self.model.graph.input.remove(input) + for name in input_to_remove: + self.model.graph.input.remove(name) - if input_to_remove or output_to_remove or nodes_to_remove: + if input_to_remove or output_to_remove or num_nodes_removed > 0: removed = [] if input_to_remove: removed.append(f"{len(input_to_remove)} inputs") if output_to_remove: removed.append(f"{len(output_to_remove)} outputs") - if nodes_to_remove: - removed.append(f"{len(nodes_to_remove)} nodes") + if num_nodes_removed > 0: + removed.append(f"{num_nodes_removed} nodes") logger.info("Removed %s", ", ".join(removed)) self.update_graph() @@ -1087,33 +1126,78 @@ def get_operator_statistics(self, include_domain=False): op = (node.domain + ":" if include_domain and node.domain else "") + node.op_type op_count[op] = 1 if op not in op_count else (op_count[op] + 1) - logger.info(f"Operators:{op_count}") + # Sorted by count in the descending order, then by key in alphabetical order. + logger.info(f"Operators:{sorted(op_count.items(), key=lambda kv:(-kv[1], kv[0]))}") + return op_count @staticmethod - def has_same_value(tensor1: TensorProto, tensor2: TensorProto, require_raw_data: bool = False) -> bool: + def to_data_hash(tensor: TensorProto, base_dir: str = "") -> int: + """Converts a tensor def object to a hash for data comparison purposes. + Args: + tensor: a TensorProto object. + base_dir: if external tensor exists, base_dir can help to find the path to it + Returns: + hash: a hash of the data. + """ + if tensor.HasField("segment"): + raise ValueError("Currently not supporting loading segments.") + if tensor.data_type == TensorProto.UNDEFINED: + raise TypeError("The element type in the input tensor is not defined.") + tensor_dtype = tensor.data_type + storage_field = helper.tensor_dtype_to_field(tensor_dtype) + + if tensor.data_type == TensorProto.STRING: + utf8_strings = getattr(tensor, storage_field) + return hash(tuple(s.decode("utf-8") for s in utf8_strings)) + # Load raw data from external tensor if it exists + if uses_external_data(tensor): + load_external_data_for_tensor(tensor, base_dir) + if tensor.HasField("raw_data"): + return hash(tensor.raw_data) + else: + np_data = numpy_helper.to_array(tensor) + return hash(np_data.tobytes()) + + @staticmethod + def has_same_value( + tensor1: TensorProto, + tensor2: TensorProto, + signature_cache1: Optional[dict] = None, + signature_cache2: Optional[dict] = None, + ) -> bool: """Returns True when two tensors have same value. Note that name can be different. Args: tensor1 (TensorProto): initializer 1 tensor2 (TensorProto): initializer 2 - require_raw_data (bool): ignore tensors without raw_data - Note: Flag can speed up runtime significantly - + signature_cache1 (dict): Optional dictionary to store data signatures of tensor1 in order to speed up comparison. + signature_cache2 (dict): Optional dictionary to store data signatures of tensor2 in order to speed up comparison. Returns: - bool: True when two intializers has same value. + bool: True when two initializers has same value. """ - if tensor1.data_type != tensor2.data_type or tensor1.dims != tensor2.dims: - return False - if tensor1.HasField("raw_data") and tensor2.HasField("raw_data"): - return tensor1.raw_data == tensor2.raw_data - if require_raw_data: - return False + sig1 = ( + signature_cache1[tensor1.name] + if signature_cache1 and tensor1.name in signature_cache1 + else OnnxModel.to_data_hash(tensor1) + ) + sig2 = ( + signature_cache2[tensor2.name] + if signature_cache2 and tensor2.name in signature_cache2 + else OnnxModel.to_data_hash(tensor2) + ) + if signature_cache1 is not None: + signature_cache1[tensor1.name] = sig1 + if signature_cache2 is not None: + signature_cache2[tensor2.name] = sig2 + if sig1 == sig2 and tensor1.data_type == tensor2.data_type and tensor1.dims == tensor2.dims: + # Same signature, now do the expensive check to confirm the data is the same + return (numpy_helper.to_array(tensor1) == numpy_helper.to_array(tensor2)).all() - return (numpy_helper.to_array(tensor1) == numpy_helper.to_array(tensor2)).all() + return False - def remove_duplicated_initializer(self, require_raw_data: bool = False): + def remove_duplicated_initializer(self, cache: Optional[dict] = None): """Remove initializers with duplicated values, and only keep the first one. It could help reduce size of models (like ALBert) with shared weights. If require_raw_data passed, method will only compare raw_data initializers to speed runtime @@ -1130,7 +1214,7 @@ def remove_duplicated_initializer(self, require_raw_data: bool = False): continue for j in range(i + 1, initializer_count): if OnnxModel.has_same_value( - self.model.graph.initializer[i], self.model.graph.initializer[j], require_raw_data + self.model.graph.initializer[i], self.model.graph.initializer[j], cache, cache ): same[j] = i diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 995f8c6541..882100a0d0 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -22,7 +22,9 @@ from fusion_qordered_layernorm import FusionQOrderedLayerNormalization from fusion_qordered_matmul import FusionQOrderedMatMul from fusion_reshape import FusionReshape +from fusion_rotary_attention import FusionRotaryEmbeddings from fusion_shape import FusionShape +from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization from fusion_utils import FusionUtils from onnx import GraphProto, ModelProto, TensorProto, ValueInfoProto, helper @@ -106,10 +108,36 @@ def fuse_layer_norm(self): fusion = FusionQOrderedLayerNormalization(self) fusion.apply() + def fuse_simplified_layer_norm(self): + fusion = FusionSimplifiedLayerNormalization(self) + fusion.apply() + def fuse_skip_layer_norm(self): fusion = FusionSkipLayerNormalization(self) fusion.apply() + def fuse_skip_simplified_layer_norm(self): + fusion = FusionSkipSimplifiedLayerNormalization(self) + fusion.apply() + + def fuse_rotary_embeddings(self): + fusion = FusionRotaryEmbeddings(self) + fusion.apply() + # Remove non-MS domain functions + rot_emb_nodes = list( + filter( + lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", self.model.graph.node + ) + ) + non_ms_domains_to_keep = set(map(lambda node: node.domain, rot_emb_nodes)) + i = 0 + while i < len(self.model.functions): + fn = self.model.functions[i] + if "RotaryEmbedding" in fn.name and fn.domain not in non_ms_domains_to_keep: + self.model.functions.remove(fn) + else: + i += 1 + # Only relevant in models with Q-DQ nodes def fuse_qordered_mamtul(self): fusion = FusionQOrderedMatMul(self) @@ -367,6 +395,7 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if (options is None) or options.enable_layer_norm: self.fuse_layer_norm() + self.fuse_simplified_layer_norm() if (options is None) or options.enable_gelu: self.fuse_gelu() @@ -377,6 +406,10 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if (options is None) or options.enable_skip_layer_norm: self.fuse_skip_layer_norm() + self.fuse_skip_simplified_layer_norm() + + if (options is None) or options.enable_rotary_embeddings: + self.fuse_rotary_embeddings() if options is not None: self.attention_mask.set_mask_format(options.attention_mask_format) @@ -442,38 +475,56 @@ def get_fused_operator_statistics(self): "BiasGelu", "GemmFastGelu", "LayerNormalization", + "SimplifiedLayerNormalization", "SkipLayerNormalization", + "SkipSimplifiedLayerNormalization", + "RotaryEmbedding", ] q_ops = ["QOrderedAttention", "QOrderedGelu", "QOrderedLayerNormalization", "QOrderedMatMul"] for op in ops + q_ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes) - logger.info(f"Optimized operators:{op_count}") + logger.info(f"Optimized operators: {op_count}") return op_count - def is_fully_optimized(self): + def is_fully_optimized(self, fused_op_count=None): """ Returns True when the model is fully optimized. """ - op_count = self.get_fused_operator_statistics() - embed = op_count["EmbedLayerNormalization"] - attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"] - gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"] - layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"] - is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention) + if fused_op_count is None: + fused_op_count = self.get_fused_operator_statistics() + + def op_count(op_name: str): + return fused_op_count.get(op_name) or 0 + + embed = op_count("EmbedLayerNormalization") + attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("QOrderedAttention") + gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu") + layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization") + simple_layer_norm = op_count("SimplifiedLayerNormalization") + op_count("SkipSimplifiedLayerNormalization") + + is_perfect = ( + (embed > 0) + and (attention > 0) + and (attention == gelu) + and ((layer_norm >= 2 * attention) or (simple_layer_norm >= 2 * attention)) + ) if layer_norm == 0: logger.debug("Layer Normalization not fused") + if simple_layer_norm == 0: + logger.debug("Simple Layer Normalization not fused") + if gelu == 0: - logger.debug("Gelu/FastGelu not fused") + logger.debug("Gelu (or FastGelu) not fused") if embed == 0: - logger.debug("Embed Layer not fused") + logger.debug("EmbedLayerNormalization not fused") if attention == 0: - logger.warning("Attention not fused") + logger.warning("Attention (or MultiHeadAttention) not fused") return is_perfect diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py index 1229825fec..c781a91c9e 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py @@ -435,7 +435,7 @@ def remove_extra_reshape_2(self): "SkipLayerNormalization", ], [None, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ) # yapf: disable + ) if path is None: continue diff --git a/onnxruntime/python/tools/transformers/onnx_model_clip.py b/onnxruntime/python/tools/transformers/onnx_model_clip.py index 93e8623768..9b4ca03a47 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_clip.py +++ b/onnxruntime/python/tools/transformers/onnx_model_clip.py @@ -5,15 +5,17 @@ from logging import getLogger +from fusion_attention_clip import FusionAttentionClip from onnx import ModelProto -from onnx_model_unet import UnetOnnxModel +from onnx_model_bert import BertOnnxModel logger = getLogger(__name__) -class ClipOnnxModel(UnetOnnxModel): +class ClipOnnxModel(BertOnnxModel): def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0): super().__init__(model, num_heads=num_heads, hidden_size=hidden_size) + self.clip_attention_fusion = FusionAttentionClip(self, self.hidden_size, self.num_heads) def get_fused_operator_statistics(self): """ @@ -31,3 +33,6 @@ def get_fused_operator_statistics(self): logger.info(f"Optimized operators:{op_count}") return op_count + + def fuse_attention(self): + self.clip_attention_fusion.apply() diff --git a/onnxruntime/python/tools/transformers/onnx_model_gpt2.py b/onnxruntime/python/tools/transformers/onnx_model_gpt2.py index 263857ffbc..6545bb08cd 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_gpt2.py +++ b/onnxruntime/python/tools/transformers/onnx_model_gpt2.py @@ -8,6 +8,7 @@ from fusion_gpt_attention import FusionGptAttention from fusion_gpt_attention_megatron import FusionGptAttentionMegatron from fusion_gpt_attention_no_past import FusionGptAttentionNoPast +from fusion_rotary_attention import FusionRotaryAttention from onnx_model_bert import BertOnnxModel logger = logging.getLogger(__name__) @@ -27,6 +28,9 @@ def fuse_attention(self): fusion = FusionGptAttentionMegatron(self, self.num_heads) fusion.apply() + fusion = FusionRotaryAttention(self, self.hidden_size, self.num_heads) + fusion.apply() + def postprocess(self): """ Remove extra reshape nodes. @@ -94,4 +98,4 @@ def postprocess(self): reshape_count += 2 self.prune_graph() - logger.info(f"postprocess: remove Reshape count:{reshape_count}") + logger.info(f"postprocess: remove Reshape count: {reshape_count}") diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index 8fb31da4a6..95f40af3fd 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -3,12 +3,12 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from typing import Dict, Optional, Union +from typing import Optional, Union import numpy as np from fusion_attention import AttentionMask, FusionAttention from fusion_base import Fusion -from fusion_skiplayernorm import FusionSkipLayerNormalization +from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization from fusion_utils import NumpyHelper from onnx import NodeProto, TensorProto, helper from onnx_model import OnnxModel @@ -56,8 +56,8 @@ def create_attention_node( Args: mask_index (str): mask input q_matmul (NodeProto): MatMul node in fully connection for Q - k_matmul (NodeProto): MatMul node in fully connection for K - v_matmul (NodeProto): MatMul node in fully connection for V + k_matmul (NodeProto): MatMul node in fully connection for K + v_matmul (NodeProto): MatMul node in fully connection for V num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. input (str): input name @@ -111,7 +111,8 @@ def create_attention_node( name=attention_node_name + "_qkv_weight", data_type=TensorProto.FLOAT, dims=[qw_in_size, qkv_weight_dim], - vals=qkv_weight.flatten().tolist(), + vals=qkv_weight.tobytes(), + raw=True, ) self.model.add_initializer(weight, self.this_graph_name) @@ -665,7 +666,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): name=self.model.create_node_name("bias_table_weight", name_prefix=node_name_prefix), data_type=TensorProto.FLOAT, dims=[np.shape(table_weight)[0], np.shape(table_weight)[1]], - vals=table_weight_t.flatten().tolist(), + vals=table_weight_t.tobytes(), + raw=True, ) self.model.add_initializer(bias_table, self.this_graph_name) @@ -685,67 +687,6 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name -class FusionSimplifiedLayerNormalization(Fusion): - def __init__(self, model: OnnxModel): - super().__init__(model, "SimplifiedLayerNormalization", "Mul") - - def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): - if node.op_type != "Mul": - return - - sim_ln_nodes = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], - [1, 1, 1, 0, 0, 0, 0], - ) - if sim_ln_nodes is None: - sim_ln_nodes = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"], - [1, 1, 1, 0, 0, 0, 0], - ) - if sim_ln_nodes is None: - return - - pow_node = sim_ln_nodes[-2] - if self.model.find_constant_input(pow_node, 2.0) != 1: - return - - root_input = pow_node.input[0] - - mul_node_1 = sim_ln_nodes[0] - if root_input != mul_node_1.input[0]: - return - - second_add_node = sim_ln_nodes[3] - i, add_weight = self.model.get_constant_input(second_add_node) - if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: - logger.warning(f"epsilon value is not expeced: {add_weight}") - return - - self.nodes_to_remove.extend(sim_ln_nodes[:-1]) - - normalize_node = helper.make_node( - "SimplifiedLayerNormalization", - inputs=[root_input, node.input[0]], - outputs=[node.output[0]], - name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"), - ) - normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) - normalize_node.attribute.extend([helper.make_attribute("axis", int(-1))]) - normalize_node.attribute.extend([helper.make_attribute("stash_type", int(1))]) - self.nodes_to_add.append(normalize_node) - self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name - - -class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization): - def __init__(self, model: OnnxModel): - super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization") - - def fuse(self, node, input_name_to_nodes, output_name_to_node): - super().fuse(node, input_name_to_nodes, output_name_to_node) - - class T5OnnxModel(BertOnnxModel): def __init__(self, model, num_heads, hidden_size): super().__init__(model, num_heads, hidden_size) diff --git a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py index d1815394e9..98235de6ba 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py +++ b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py @@ -5,10 +5,9 @@ import logging from typing import Union -import numpy as np from fusion_attention import AttentionMask, FusionAttention from fusion_utils import NumpyHelper -from onnx import NodeProto, TensorProto, helper, numpy_helper +from onnx import NodeProto, helper from onnx_model import OnnxModel from onnx_model_bert import BertOnnxModel @@ -57,26 +56,24 @@ def create_attention_node( attention_node_name = self.model.create_node_name("Attention") + tensor_dtype = weight.data_type + np_type = helper.tensor_dtype_to_np_dtype(tensor_dtype) weight = helper.make_tensor( name=attention_node_name + "_qkv_weight", - data_type=TensorProto.FLOAT, + data_type=tensor_dtype, dims=[hidden_size, 3 * hidden_size], - vals=qkv_weight.flatten().tolist(), + vals=qkv_weight.astype(np_type).tobytes(), + raw=True, ) - - # Sometimes weights and bias are stored in fp16 - if weight.data_type == 10: - weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name)) self.model.add_initializer(weight, self.this_graph_name) bias = helper.make_tensor( name=attention_node_name + "_qkv_bias", - data_type=TensorProto.FLOAT, + data_type=tensor_dtype, dims=[3 * hidden_size], - vals=qkv_bias.flatten().tolist(), + vals=qkv_bias.astype(np_type).tobytes(), + raw=True, ) - if bias.data_type == 10: - bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name)) self.model.add_initializer(bias, self.this_graph_name) attention_inputs = [ diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 294641dd1e..4d15b9288e 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -12,6 +12,7 @@ from fusion_group_norm import FusionGroupNorm from fusion_nhwc_conv import FusionNhwcConv from fusion_options import FusionOptions +from fusion_skip_group_norm import FusionSkipGroupNorm from fusion_transpose import FusionInsertTranspose, FusionTranspose from onnx import ModelProto from onnx_model import OnnxModel @@ -57,8 +58,8 @@ def remove_useless_div(self): logger.info("Removed %d Div nodes", len(nodes_to_remove)) def convert_conv_to_nhwc(self): - # Do not update weight here since save external data has a bug - conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=False) + # Transpose weights in offline might help since ORT does not apply constant-folding on Transpose nodes. + conv_to_nhwc_conv = FusionNhwcConv(self, update_weight=True) conv_to_nhwc_conv.apply() def merge_adjacent_transpose(self): @@ -150,6 +151,10 @@ def optimize(self, options: Optional[FusionOptions] = None): # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. self.utils.remove_useless_reshape_nodes() + if (options is None) or options.enable_skip_group_norm: + skip_group_norm_fusion = FusionSkipGroupNorm(self) + skip_group_norm_fusion.apply() + if (options is None) or options.enable_bias_skip_layer_norm: # Fuse SkipLayerNormalization and Add Bias before it. self.fuse_add_bias_skip_layer_norm() @@ -181,6 +186,7 @@ def get_fused_operator_statistics(self): "SkipLayerNormalization", "BiasSplitGelu", "GroupNorm", + "SkipGroupNorm", "NhwcConv", "BiasAdd", ] diff --git a/onnxruntime/python/tools/transformers/onnx_model_vae.py b/onnxruntime/python/tools/transformers/onnx_model_vae.py index 9e79014e71..de8b59074a 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_vae.py +++ b/onnxruntime/python/tools/transformers/onnx_model_vae.py @@ -32,6 +32,7 @@ def get_fused_operator_statistics(self): ops = [ "Attention", "GroupNorm", + "SkipGroupNorm", "NhwcConv", ] for op in ops: diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 3f274eb6c8..b2d6423a45 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -96,7 +96,7 @@ def optimize_by_onnxruntime( logger.error("There is no gpu for onnxruntime to do optimization.") return onnx_model_path - model = OnnxModel(load_model(onnx_model_path, format=None, load_external_data=False)) + model = OnnxModel(load_model(onnx_model_path, load_external_data=False)) if model.use_float16() and not use_gpu: logger.warning( "This model uses float16 in the graph, use_gpu=False might cause extra Cast nodes. " @@ -510,11 +510,14 @@ def main(): if args.input_int32: optimizer.change_graph_inputs_to_int32() - if args.model_type in ["bert", "gpt2"]: - if optimizer.is_fully_optimized(): - logger.info("The model has been fully optimized.") - else: - logger.info("The model has been optimized.") + # Print the operator statistics might help end user. + optimizer.get_operator_statistics() + + fused_op_count = optimizer.get_fused_operator_statistics() + if "bert" in args.model_type and optimizer.is_fully_optimized(fused_op_count): + logger.info("The model has been fully optimized.") + else: + logger.info("The model has been optimized.") if args.convert_to_packing_mode: if args.model_type == "bert": diff --git a/onnxruntime/python/tools/transformers/shape_infer_helper.py b/onnxruntime/python/tools/transformers/shape_infer_helper.py index f8a5464d8a..f1fc0c952e 100644 --- a/onnxruntime/python/tools/transformers/shape_infer_helper.py +++ b/onnxruntime/python/tools/transformers/shape_infer_helper.py @@ -28,12 +28,12 @@ def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_o self.is_inferred_: bool = False self.dynamic_axis_mapping_: Dict[str, int] = {} - def infer(self, dynamic_axis_mapping: Dict[str, int], max_runs: int = 128): + def infer(self, dynamic_axis_mapping: Dict[str, int], max_runs: int = 200): """Run shape inference, and try replace dynamic axis from string to integer when mapping is provided. Args: dynamic_axis_mapping (_type_): a dictionary with name of dynamic axis as key, like {"batch_size" : 4} - max_runs (int, optional): limit maximum number of runs to avoid infinite loop. Defaults to 32. + max_runs (int, optional): limit maximum number of runs to avoid infinite loop. Defaults to 200. Returns: bool: whether all shapes has been inferred or not. diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 0e66a22e59..b652e0723f 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -2112,9 +2112,11 @@ static void RunModelWithRandomInput( constexpr int hidden_size = 768; constexpr int num_heads = 12; + const float min_value = is_float16 ? -0.001f : -1.0f; + const float max_value = is_float16 ? 0.001f : 1.0f; std::vector batch_input_dims{1, sequence_length, hidden_size}; - std::vector batch_input_data = random.Uniform(batch_input_dims, -1.0f, 1.0f); + std::vector batch_input_data = random.Uniform(batch_input_dims, min_value, max_value); std::vector input_dims{batch_size, sequence_length, hidden_size}; std::vector input_data; @@ -2123,12 +2125,12 @@ static void RunModelWithRandomInput( } std::vector weight_dims{hidden_size, 3 * hidden_size}; - std::vector weight_data = random.Uniform(weight_dims, -1.0f, 1.0f); + std::vector weight_data = random.Uniform(weight_dims, min_value, max_value); std::vector bias_dims{3 * hidden_size}; - std::vector bias_data = random.Uniform(bias_dims, -1.0f, 1.0f); + std::vector bias_data = random.Uniform(bias_dims, min_value, max_value); - float gpu_threshold = is_float16 ? static_cast(sequence_length) / 32.0f : 0.005f; + float gpu_threshold = is_float16 ? 0.5f : 0.005f; constexpr float cpu_threshold = 0.002f; bool enable_cuda = HasCudaEnvironment(is_float16 ? 530 : 0); bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); @@ -2146,7 +2148,10 @@ static void RunModelWithRandomInput( test.AddInput("weight", weight_dims, weight_data); test.AddInput("bias", bias_dims, bias_data); } - test.AddInput("mask_index", mask_index_dims, mask_index_data); + if (mask_index_data.size() > 0) { + test.AddInput("mask_index", mask_index_dims, mask_index_data); + } + test.AddReferenceOutputs(onnx_model, gpu_threshold); std::vector> execution_providers; if (enable_cuda) { @@ -2216,6 +2221,25 @@ TEST(AttentionTest, Attention_Mask1D_Fp32_B2_S64) { false); } +// This case can be used to test flash attention using Ampere GPU +TEST(AttentionTest, Attention_NoMask_Fp16) { + constexpr int batch_size = 2; + std::vector sequence_lengths{1, 7, 8}; + for (const auto& sequence_length : sequence_lengths) { + std::vector mask_index_dims{}; + std::vector mask_index_data{}; + std::string onnx_model = "testdata/attention_no_mask_fp16.onnx"; + + RunModelWithRandomInput( + batch_size, + sequence_length, + mask_index_dims, + mask_index_data, + onnx_model, + true); + } +} + // This test is disabled since it is flaky. TEST(AttentionTest, DISABLED_Attention_Mask1D_Fp16_B2_FusedNoPadding) { constexpr int batch_size = 2; diff --git a/onnxruntime/test/contrib_ops/bias_add_op_test.cc b/onnxruntime/test/contrib_ops/bias_add_op_test.cc index 7699f4479c..6fd091ef66 100644 --- a/onnxruntime/test/contrib_ops/bias_add_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_add_op_test.cc @@ -107,6 +107,20 @@ TEST(BiasAddTest, BiasAddTest_HiddenSize_1280) { constexpr int64_t num_channels = 1280; RunBiasAddTest(batch_size, image_size, num_channels); } + +TEST(BiasAddTest, BiasAddTest_HiddenSize_768) { + constexpr int64_t batch_size = 2; + constexpr int64_t image_size = 5; + constexpr int64_t num_channels = 768; + RunBiasAddTest(batch_size, image_size, num_channels); +} + +TEST(BiasAddTest, BiasAddTest_HiddenSize_1536) { + constexpr int64_t batch_size = 1; + constexpr int64_t image_size = 3; + constexpr int64_t num_channels = 1536; + RunBiasAddTest(batch_size, image_size, num_channels); +} #endif } // namespace test diff --git a/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc b/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc index db14eb3da4..a979717d23 100644 --- a/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc @@ -152,6 +152,20 @@ TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_10240) { RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size); } +TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_6144) { + constexpr int64_t batch_size = 2; + constexpr int64_t sequence_length = 3; + constexpr int64_t hidden_size = 6144; + RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size); +} + +TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_12288) { + constexpr int64_t batch_size = 1; + constexpr int64_t sequence_length = 2; + constexpr int64_t hidden_size = 12288; + RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size); +} + #endif } // namespace test diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc new file mode 100644 index 0000000000..dc8efbbaf3 --- /dev/null +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -0,0 +1,163 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef ORT_MINIMAL_BUILD + +#include "core/common/span_utils.h" +#include "core/framework/tensor.h" +#include "core/mlas/inc/mlas_q4.h" +#include "core/mlas/inc/mlas.h" +#include "core/session/inference_session.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" +#include "core/util/qmath.h" +#include "contrib_ops/cpu/quantization/dequantize_blockwise.h" + +#include +#include + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +namespace onnxruntime { +namespace test { + +void QuantizeDequantize(std::vector& raw_vals, + std::vector& quant_vals, + std::vector& scales, + std::vector* zp, + int32_t N, + int32_t K, + int32_t block_size) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + contrib::QuantizeBlockwise( + quant_vals.data(), + raw_vals.data(), + scales.data(), + zp != nullptr ? zp->data() : nullptr, + block_size, + 4, + N, + K, + tp.get()); + + // Note that input1_f_vals is NxK after dequant + contrib::DequantizeBlockwise( + raw_vals.data(), + quant_vals.data(), + scales.data(), + zp != nullptr ? zp->data() : nullptr, + block_size, + 4, + N, + K, + tp.get()); +} + +void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zeropoint, bool use_float16) { + RandomValueGenerator random{1234}; + std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f)); + std::vector input1_f_vals(random.Gaussian(std::vector({K, N}), 0.0f, 0.25f)); + +#if 0 // for Debugging + std::vector input1_f_vals_trans(N * K); + MlasTranspose(input1_f_vals.data(), input1_f_vals_trans.data(), K, N); +#endif + + int64_t block_per_k = (K + block_size - 1) / block_size; + int64_t number_of_block = block_per_k * N; + int64_t block_blob_size = block_size * 4 / 8; + int64_t buf_size = number_of_block * (block_size * 4 / 8); + std::vector input1_vals(buf_size); + std::vector scales(number_of_block); + std::vector zp((N * block_per_k + 1) / 2); + + QuantizeDequantize(input1_f_vals, + input1_vals, + scales, + has_zeropoint ? &zp : nullptr, + static_cast(N), + static_cast(K), + static_cast(block_size)); + + std::vector expected_vals(M * N); + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + sum += input0_vals[m * K + k] * input1_f_vals[n * K + k]; + } + expected_vals[m * N + n] = sum; + } + } + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("bits", 4); + if (use_float16) { + test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); + test.AddInput("B", {N, block_per_k, block_blob_size}, input1_vals, true); + test.AddInput("scales", {N * block_per_k}, ToFloat16(scales), true); + if (has_zeropoint) { + test.AddInput("zero_points", {(N * block_per_k + 1) / 2}, zp, true); + } + + test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); + test.SetOutputAbsErr("Y", 0.02f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } else { + test.AddInput("A", {M, K}, input0_vals, false); + test.AddInput("B", {N, block_per_k, block_blob_size}, input1_vals, true); + test.AddInput("scales", {N * block_per_k}, scales, true); + if (has_zeropoint) { + test.AddInput("zero_points", {(N * block_per_k + 1) / 2}, zp, true); + } + + test.AddOutput("Y", {M, N}, expected_vals); + + test.Run(); + } +} + +TEST(MatMulNBits, Float32) { + for (auto M : {1, 2, 100}) { + for (auto N : {1, 2, 32, 288}) { + for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto block_size : {16, 32, 64, 128}) { + RunTest(M, N, K, block_size, false, false); + RunTest(M, N, K, block_size, true, false); + } + } + } + } +} + +#if defined(USE_CUDA) +TEST(MatMulNBits, Float16) { + for (auto M : {1, 2, 100}) { + for (auto N : {1, 2, 32, 288}) { + for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto block_size : {16, 32, 64, 128}) { + RunTest(M, N, K, block_size, false, true); + RunTest(M, N, K, block_size, true, true); + } + } + } + } +} + +#endif +} // namespace test +} // namespace onnxruntime + +#endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc new file mode 100644 index 0000000000..e739b17d58 --- /dev/null +++ b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef ORT_MINIMAL_BUILD + +#include "core/common/span_utils.h" +#include "core/framework/tensor.h" +#include "core/mlas/inc/mlas_q4.h" +#include "core/mlas/inc/mlas.h" +#include "core/session/inference_session.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" +#include "core/util/qmath.h" +#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h" + +#include +#include + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +namespace onnxruntime { +namespace test { + +void QuantizeDequantizeBnb4(std::vector& raw_vals, // N X K + std::vector& quant_vals, + std::vector& absmax, + int32_t quant_type, + int32_t N, + int32_t K, + int32_t block_size) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + + contrib::QuantizeBlockwiseBnb4( + quant_vals.data(), + raw_vals.data(), + absmax.data(), + block_size, + quant_type, + N, + K, + tp.get()); + + contrib::DequantizeBlockwiseBnb4( + raw_vals.data(), + quant_vals.data(), + absmax.data(), + block_size, + quant_type, + N, + K, + tp.get()); +} + +void RunTest(int64_t quant_type, int64_t M, int64_t N, int64_t K, int64_t block_size, bool use_float16) { + RandomValueGenerator random{1234}; + std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f)); + // quantizer expects transposed weights, N X K + std::vector input1_f_vals(random.Gaussian(std::vector({N, K}), 0.0f, 0.25f)); + + int64_t numel = N * K; + int64_t quantized_numel = (numel + 1) / 2; + int64_t total_block_count = (numel + block_size - 1) / block_size; + std::vector input1_vals(quantized_numel); + std::vector absmax(total_block_count); + + QuantizeDequantizeBnb4(input1_f_vals, + input1_vals, + absmax, + static_cast(quant_type), + static_cast(N), + static_cast(K), + static_cast(block_size)); + + std::vector expected_vals(M * N); + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + sum += input0_vals[m * K + k] * input1_f_vals[n * K + k]; + } + expected_vals[m * N + n] = sum; + } + } + + OpTester test("MatMulBnb4", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("quant_type", quant_type); + if (use_float16) { + test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); + test.AddInput("B", {quantized_numel}, input1_vals, true); + test.AddInput("absmax", {total_block_count}, ToFloat16(absmax), true); + + test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); + test.SetOutputAbsErr("Y", 0.02f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } else { + test.AddInput("A", {M, K}, input0_vals, false); + test.AddInput("B", {quantized_numel}, input1_vals, true); + test.AddInput("absmax", {total_block_count}, absmax, true); + + test.AddOutput("Y", {M, N}, expected_vals); + + test.Run(); + } +} + +TEST(MatMulBnb4, Float32) { + for (auto qt : {0, 1}) { + for (auto M : {1, 2, 100}) { + for (auto N : {1, 2, 32, 288}) { + for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto block_size : {16, 32, 64, 128}) { + RunTest(qt, M, N, K, block_size, false); + } + } + } + } + } +} + +#if defined(USE_CUDA) +TEST(MatMulBnb4, Float16) { + for (auto qt : {0, 1}) { + for (auto M : {1, 2, 100}) { + for (auto N : {1, 2, 32, 288}) { + for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto block_size : {16, 32, 64, 128}) { + RunTest(qt, M, N, K, block_size, true); + } + } + } + } + } +} + +#endif +} // namespace test +} // namespace onnxruntime + +#endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc b/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc index dd886ed1c6..09ae5eddb1 100644 --- a/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc @@ -24,7 +24,7 @@ namespace onnxruntime { namespace test { TEST(MatMulFpQ4, MatMul2DSym) { - // (100 x 41) X (41 x 288) + // (100 x 52) X (52 x 288) constexpr int64_t M = 100; constexpr int64_t N = 288; constexpr int64_t K = 52; diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index c2230501b0..49b338d832 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -300,6 +300,7 @@ static void RunMultiHeadAttentionKernel( if (kernel_type == AttentionKernelType::AttentionKernel_Default) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, @@ -315,6 +316,7 @@ static void RunMultiHeadAttentionKernel( if (kernel_type == AttentionKernelType::AttentionKernel_Unfused) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -330,6 +332,7 @@ static void RunMultiHeadAttentionKernel( if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedCrossAttention) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, @@ -342,10 +345,11 @@ static void RunMultiHeadAttentionKernel( return; } -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION if (kernel_type == AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -362,6 +366,7 @@ static void RunMultiHeadAttentionKernel( if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedAttention) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -388,9 +393,9 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); } -#if USE_FLASH_ATTENTION - if (data.sequence_length >= contrib::attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 || - data.kv_sequence_length >= contrib::attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) { +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32 || + data.kv_sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32) { kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( @@ -434,7 +439,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); } -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( diff --git a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc index dd9224df8f..09baf8def0 100644 --- a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc @@ -433,7 +433,8 @@ static void RunModelWithRandomInput( std::vector token_offset_dims{batch_size, sequence_length}; std::vector cum_seq_len_dims{batch_size + 1}; - float gpu_threshold = is_float16 ? 0.1f : 0.005f; + float gpu_threshold = is_float16 ? 0.15f : 0.005f; + gpu_threshold *= sequence_length > 1024 ? 4.0f : 1.0f; // threshold should increase with sequence length bool enable_cuda = HasCudaEnvironment(is_float16 ? 530 : 0); if (enable_cuda) { OpTester test("PackedAttention", 1, onnxruntime::kMSDomain); diff --git a/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc index fc2b58680c..2225395556 100644 --- a/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc @@ -160,6 +160,7 @@ static void RunPackedMultiHeadAttentionTest( if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedAttention) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -168,10 +169,11 @@ static void RunPackedMultiHeadAttentionTest( InvokePackedMultiHeadAttentionTest(true, false); } -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION if (kernel_type == AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -182,9 +184,20 @@ static void RunPackedMultiHeadAttentionTest( } #endif +#if USE_FLASH_ATTENTION + if (kernel_type == AttentionKernelType::AttentionKernel_FlashAttention) { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "0"}, + {onnxruntime::contrib::attention::kMinSeqLenForFlashAttentionPackedQKV, "0"}}}; + InvokePackedMultiHeadAttentionTest(true, true); + } +#endif + if (kernel_type == AttentionKernelType::AttentionKernel_Unfused) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -389,6 +402,32 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_cutlass) { AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention); } +#if USE_FLASH_ATTENTION +TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_FlashAttention) { + if (HasCudaEnvironment(800)) { + PackedAttentionTestData data; + GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(data); + std::vector empty_data = {}; + + RunPackedMultiHeadAttentionTest( + data.qkv_data, + empty_data, + empty_data, + empty_data, + data.token_offset, + data.cumulative_sequence_length, + data.fp16_output_data, + data.batch_size, + data.sequence_length, + data.hidden_size, + data.v_hidden_size, + data.num_heads, + data.token_count, + AttentionKernelType::AttentionKernel_FlashAttention); + } +} +#endif + TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_unfused) { PackedAttentionTestData data; GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(data); diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc new file mode 100644 index 0000000000..55f01bf0d3 --- /dev/null +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -0,0 +1,641 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "gtest/gtest.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static void RunTest( + const std::vector& input_data, + const std::vector& position_ids, + const std::vector& cos_cache, + const std::vector& sin_cache, + const std::vector& output_data, + int batch_size, + int sequence_length, + int head_size, + int num_heads, + int max_sequence_length, + int64_t interleaved, + bool use_float16, + bool disable_cpu, + bool disable_cuda, + bool disable_dml) { + // input : (batch_size, sequence_length, hidden_size) + // position ids : (1) or (batch_size, sequence_length) + // cos cache : (max_sequence_length, head_size / 2) + // sin cache : (max_sequence_length, head_size / 2) + // interleaved : 0 = false, 1 = true + + int hidden_size = num_heads * head_size; + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector pos_dims; + std::vector cache_dims = {max_sequence_length, head_size / 2}; + + assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0); + assert(max_sequence_length >= sequence_length); + if (position_ids.size() == 1) { + pos_dims = {1}; + } else { + pos_dims = {batch_size, sequence_length}; + } + + std::string op_type = "RotaryEmbedding"; + std::vector> execution_providers; + + int min_cuda_architecture = use_float16 ? 530 : 0; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; + + if (enable_cuda && !disable_cuda) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + if (enable_dml && !disable_dml) { + execution_providers.push_back(DefaultDmlExecutionProvider()); + } + if (!use_float16 && !disable_cpu) { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + if (execution_providers.size() == 0) { + // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline) + return; + } + + OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); + test.AddAttribute("interleaved", interleaved); + + if (!use_float16) { + test.AddInput("input", input_dims, input_data); + test.AddInput("position_ids", pos_dims, position_ids); + test.AddInput("cos_cache", cache_dims, cos_cache); + test.AddInput("sin_cache", cache_dims, sin_cache); + test.AddOutput("output", input_dims, output_data); + } else { + test.AddInput("input", input_dims, ToFloat16(input_data)); + test.AddInput("position_ids", pos_dims, position_ids); + test.AddInput("cos_cache", cache_dims, ToFloat16(cos_cache)); + test.AddInput("sin_cache", cache_dims, ToFloat16(sin_cache)); + test.AddOutput("output", input_dims, ToFloat16(output_data)); + } + test.SetOutputAbsErr("output", 0.002f); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +static void RunTests(const std::vector& input_data, + const std::vector& position_ids, + const std::vector& cos_cache, + const std::vector& sin_cache, + const std::vector& output_data, + int batch_size, + int sequence_length, + int head_size = 0, + int num_heads = 0, + int max_sequence_length = 0, + int64_t interleaved = 0, + bool use_float16 = true) { + // FP32 test for CPU + RunTest(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved, + false, /* use_fp16 */ + false, /* disable_cpu */ + true, /* disable_cuda */ + true /* disable_dml */); + + // FP32 test for CUDA and DML + RunTest(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved, + false, /* use_fp16 */ + false, /* disable_cpu */ + false, /* disable_cuda */ + false /* disable_dml */); + + // FP16 test for CUDA and DML + if (use_float16) { + RunTest(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved, + true, /* use_fp16 */ + true, /* disable_cpu */ + false, /* disable_cuda*/ + false /* disable_dml */); + } +} + +// Interleaved = true, pos ids shape = (1) +TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 3; + int num_heads = 2; + int head_size = 4; + int max_sequence_length = 8; + int64_t interleaved = 1; // true + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.2188f, 1.1676f, -1.0574f, -0.1188f, -0.7396f, -1.2425f, -0.1752f, 0.6990f, + -0.8110f, 0.6737f, -1.1233f, -0.0919f, -0.6861f, 0.7202f, 0.1963f, 0.6142f}; + + std::vector position_ids = {0}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 0.5403f, 0.9999f, -0.4161f, 0.9998f, -0.9900f, 0.9996f, + -0.6536f, 0.9992f, 0.2837f, 0.9988f, 0.9602f, 0.9982f, 0.7539f, 0.9976f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.8415f, 0.0100f, 0.9093f, 0.0200f, 0.1411f, 0.0300f, + -0.7568f, 0.0400f, -0.9589f, 0.0500f, -0.2794f, 0.0600f, 0.6570f, 0.0699f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.6411f, -0.3948f, -1.0561f, -0.1294f, 0.6460f, -1.2937f, -0.1822f, 0.6972f, + -0.2751f, -1.0178f, -1.1212f, -0.1143f, -0.3694f, -0.9235f, 0.1840f, 0.6180f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = true, pos ids shape = (1) +TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { + int batch_size = 2; + int sequence_length = 8; + int num_heads = 4; + int head_size = 6; + int max_sequence_length = 16; + int64_t interleaved = 1; // true + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -1.0574f, + -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f, + -0.4407f, 0.1766f, 1.0224f, -0.4826f, -0.5421f, + -0.5342f, -0.6413f, 1.3314f, -0.4498f, 0.5493f, + 0.0539f, 0.2601f, 0.8570f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -1.9791f, + 0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f, + 0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f, + -0.1930f, 2.5211f, -0.0452f, -0.3105f, -0.9407f, + -0.0034f, 1.5199f, -0.8480f, 0.5266f, 0.0299f, + -0.0498f, 1.0651f, 0.8860f, -1.4702f, -0.2134f, + -0.8707f, 1.6159f, -0.2356f, 0.9444f, 0.5937f, + 0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f, + 0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f, + -1.1382f, 0.4640f, -0.4986f, 0.1289f, 2.7631f, + 0.1405f, 1.1191f, 2.1134f, -0.9754f, 0.1757f, + -0.1319f, -0.2735f, 0.3355f, -0.6008f, -1.1164f, + 0.2577f, -0.7226f, -0.9244f, 1.8737f, 0.6052f, + 1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f, + 0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f, + -0.2719f, 0.1885f, 2.1432f, 0.8527f, 0.0965f, + -0.0625f, 0.8269f, 1.0122f, -1.4482f, -0.0644f, + 0.3215f, 0.5908f, -1.4197f, 0.2113f, 0.0306f, + 0.3604f, 0.3166f, -0.8975f, -0.6393f, -1.2944f, + -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f, + 0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f, + 1.0702f, 0.8279f, -0.2969f, 0.7120f, -0.2068f, + -0.1548f, 0.1553f, 0.6207f, -0.1690f, -0.5816f, + 1.2632f, 0.0695f, 1.1862f, -1.1874f, -0.7468f, + -0.9320f, -0.8579f, -0.9647f, -0.0991f, 0.0195f, + 1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f, + -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f, + -1.6593f, 1.8127f, -1.4459f, -0.2158f, -0.9792f, + -1.4392f, 0.6508f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -0.5996f, -1.0962f, 1.6327f, 1.3951f, + 0.8784f, 0.3389f, 1.2907f, 0.3124f, 0.7299f, + 1.4220f, 0.3375f, 0.0438f, 1.8698f, -0.2635f, + -2.0799f, -0.6313f, 0.4090f, -1.1458f, 0.0784f, + -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f, + 0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f, + -0.4558f, -1.6105f, 0.2979f, 1.1537f, -1.5604f, + 1.2779f, -1.2514f, 0.6056f, 0.5763f, -3.3558f, + 0.2836f, 0.6909f, -0.7631f, 2.4451f, -0.3500f, + 1.3289f, -0.6494f, 0.3478f, 1.0038f, -0.2937f, + 0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f, + 1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f, + -1.3306f, 1.5646f, 0.3338f, 0.7105f, 0.4683f, + -0.6179f, 0.0818f, -0.0488f, -0.9810f, -1.3632f, + 0.0929f, -1.7926f, -0.2921f, -0.4792f, 0.6756f, + -0.3413f, -0.2242f, -0.2111f, 0.6282f, 0.1667f, + -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f, + 0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f, + 1.6845f, -0.0901f, 0.6106f, 2.3603f, 1.3908f, + -0.7917f, -0.6734f, -0.1213f, -1.1116f, -0.7401f, + -0.7879f, 0.0606f, -2.3337f, -1.2603f, -1.7245f, + -0.3533f, -0.9421f, -0.1776f, 0.3992f, -1.7142f, + -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f, + -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f, + 0.3133f, -1.0941f, -0.3682f, -0.0163f, -0.0645f, + -0.8101f, 0.1415f, 0.0551f, 0.5873f, -0.5887f, + -1.4733f, -0.8565f, 0.7400f, -0.5033f, 0.0553f, + 0.9265f, -0.8652f, -0.0288f, -0.2209f, 0.0610f, + 0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f, + 0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f, + 1.3569f, 0.2983f, 0.4718f, -1.1936f, 0.7928f, + -0.8665f, 0.9468f, 1.1629f, 0.0616f, -1.3136f, + -0.2764f, 0.0277f, -0.1126f, 0.2342f, -0.5866f, + -1.8219f, 1.1079f, 0.5795f, -1.4249f}; + + std::vector position_ids = {0}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f, + 0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f, + -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f, + 0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f, + 0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, + 0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f, + 0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f, + 0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f, + 0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f, + 0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -0.4713f, + -0.9540f, -0.9229f, 0.3027f, -0.5708f, -0.2363f, + -1.2713f, 0.1137f, 0.8112f, -1.1659f, -0.5824f, + -0.4419f, -0.7649f, 0.7011f, -0.4569f, -0.5639f, + -0.5328f, -0.6424f, 1.0979f, 0.8773f, 0.5462f, + 0.0793f, 0.2582f, 0.8576f, 0.2653f, 1.2295f, + -0.1839f, -0.4517f, -1.5052f, -0.4651f, 0.1155f, + -2.1237f, -0.7586f, -0.2110f, 1.1441f, -0.6304f, + 0.4186f, 0.2303f, -0.1519f, 1.1903f, 0.5382f, + -0.1906f, -1.0080f, 2.3112f, -0.2220f, -0.9655f, + -0.0099f, 1.5198f, 0.7652f, -0.6410f, 0.0365f, + -0.0452f, 1.0593f, 0.8929f, 1.4856f, 0.0038f, + -1.0865f, 1.4794f, -0.2417f, 0.9428f, -0.6894f, + -0.6293f, 0.2904f, 1.5747f, -0.4956f, 0.9199f, + -0.2424f, 0.1801f, 0.7503f, -1.4576f, 0.6529f, + -1.1340f, -0.6807f, -0.0252f, -0.3834f, 2.7394f, + 0.1308f, 1.1203f, -2.1196f, -0.9618f, 0.1970f, + -0.0972f, -0.2764f, 0.3332f, -0.4522f, 1.1844f, + 0.3867f, -0.6626f, -0.9405f, 1.8656f, 0.5053f, + -1.2361f, 1.2072f, 0.1789f, -1.1002f, 1.0129f, + 1.7702f, 0.1949f, -1.1653f, 1.6049f, -0.2755f, + -0.2749f, 2.1087f, 0.4272f, 0.8076f, 0.2900f, + -0.0714f, 0.8261f, -1.1016f, -1.3814f, -0.1366f, + 0.2981f, 0.6060f, -1.4132f, 0.0893f, -0.1939f, + 0.2779f, 0.3910f, -0.8906f, -0.6489f, -1.2496f, + 0.3383f, -0.0315f, -0.7461f, 1.1510f, 0.4445f, + 0.3203f, -0.9031f, 0.2727f, 0.2609f, 2.0968f, + 1.0974f, 0.7120f, -0.5164f, 0.7415f, -0.0031f, + -0.1568f, 0.1533f, 0.5487f, -0.3357f, -0.9064f, + 1.0546f, 0.0542f, 1.1870f, -0.4045f, -1.3431f, + -0.6094f, -1.1105f, -0.9631f, -0.1137f, -0.7219f, + 0.8582f, -1.3443f, -0.6684f, -1.0227f, -1.5929f, + -0.2622f, 0.2264f, 0.0713f, 0.1843f, -1.3387f, + -1.6797f, 2.3165f, 0.1009f, 0.1081f, -0.9969f, + -1.4488f, 0.6291f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, 0.5985f, -1.0968f, 1.5662f, 1.4693f, + 0.8776f, 0.3408f, 0.4345f, 1.2549f, 0.6631f, + 1.4543f, 0.3374f, 0.0445f, 1.2320f, 1.4311f, + -2.0483f, -0.7272f, 0.4114f, -1.1449f, 1.6283f, + -0.9524f, -1.6435f, 0.5422f, 0.9907f, -0.0708f, + 0.3972f, 0.7376f, -1.5947f, 1.6138f, -0.9586f, + -0.4600f, 0.3993f, -1.5884f, 1.2934f, -1.4467f, + 1.2833f, -1.2459f, -0.7760f, 0.3108f, -3.3677f, + -0.0287f, 0.6942f, -0.7601f, -0.6993f, 2.3690f, + 1.3834f, -0.5234f, 0.3435f, 1.0053f, 0.1604f, + -0.9560f, -1.2641f, 0.2406f, 0.4973f, 0.9206f, + -1.9987f, -1.1733f, -0.4197f, -0.0366f, -0.6720f, + -1.3350f, -1.5960f, -0.1097f, 0.6386f, 0.5624f, + -0.6184f, 0.0778f, 0.1867f, 0.9643f, -1.3629f, + -0.0972f, -1.7907f, -0.3037f, 0.8245f, -0.0789f, + -0.2940f, -0.2833f, -0.2165f, 0.6264f, -1.1726f, + 0.7926f, 1.3621f, 1.3586f, -0.9007f, -0.8138f, + -2.7421f, 1.3155f, 2.4507f, 0.0507f, 0.6305f, + 1.6900f, 0.5210f, -0.3309f, 2.0630f, 1.8026f, + -0.7859f, -0.6802f, -1.1003f, -0.1990f, -0.5391f, + -0.9370f, 0.0857f, -2.3330f, -2.0112f, 0.7193f, + -0.1272f, -0.9981f, -0.1818f, 0.3973f, -0.9963f, + 1.4929f, -1.0109f, 0.4304f, 1.0160f, -1.4590f, + 0.2682f, 1.5658f, 0.1762f, 0.3038f, -0.7491f, + 0.3052f, -1.1534f, -0.0478f, 0.0021f, -0.0665f, + -0.8118f, 0.1310f, 0.2171f, 0.5485f, -0.1610f, + -1.5784f, -0.8660f, 0.7289f, -0.4678f, 0.1937f, + 1.1287f, -0.5772f, -0.0259f, -0.2212f, 0.2479f, + 0.6336f, 0.6407f, -0.6543f, 0.3838f, 0.9039f, + 0.4724f, 0.7117f, 1.0165f, 1.0270f, 1.1908f, + 1.3750f, -0.0850f, 0.5517f, -1.3842f, 0.3703f, + -0.8806f, 0.9336f, 0.8362f, 0.8105f, -1.1566f, + -0.6813f, 0.0294f, -0.1122f, 0.5620f, -0.2884f, + -2.0803f, 0.4684f, 0.6009f, -1.4160f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = false, pos ids shape = (1) +TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { + int batch_size = 2; + int sequence_length = 8; + int num_heads = 4; + int head_size = 6; + int max_sequence_length = 16; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -1.0574f, + -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f, + -0.4407f, 0.1766f, 1.0224f, -0.4826f, -0.5421f, + -0.5342f, -0.6413f, 1.3314f, -0.4498f, 0.5493f, + 0.0539f, 0.2601f, 0.8570f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -1.9791f, + 0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f, + 0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f, + -0.1930f, 2.5211f, -0.0452f, -0.3105f, -0.9407f, + -0.0034f, 1.5199f, -0.8480f, 0.5266f, 0.0299f, + -0.0498f, 1.0651f, 0.8860f, -1.4702f, -0.2134f, + -0.8707f, 1.6159f, -0.2356f, 0.9444f, 0.5937f, + 0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f, + 0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f, + -1.1382f, 0.4640f, -0.4986f, 0.1289f, 2.7631f, + 0.1405f, 1.1191f, 2.1134f, -0.9754f, 0.1757f, + -0.1319f, -0.2735f, 0.3355f, -0.6008f, -1.1164f, + 0.2577f, -0.7226f, -0.9244f, 1.8737f, 0.6052f, + 1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f, + 0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f, + -0.2719f, 0.1885f, 2.1432f, 0.8527f, 0.0965f, + -0.0625f, 0.8269f, 1.0122f, -1.4482f, -0.0644f, + 0.3215f, 0.5908f, -1.4197f, 0.2113f, 0.0306f, + 0.3604f, 0.3166f, -0.8975f, -0.6393f, -1.2944f, + -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f, + 0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f, + 1.0702f, 0.8279f, -0.2969f, 0.7120f, -0.2068f, + -0.1548f, 0.1553f, 0.6207f, -0.1690f, -0.5816f, + 1.2632f, 0.0695f, 1.1862f, -1.1874f, -0.7468f, + -0.9320f, -0.8579f, -0.9647f, -0.0991f, 0.0195f, + 1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f, + -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f, + -1.6593f, 1.8127f, -1.4459f, -0.2158f, -0.9792f, + -1.4392f, 0.6508f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -0.5996f, -1.0962f, 1.6327f, 1.3951f, + 0.8784f, 0.3389f, 1.2907f, 0.3124f, 0.7299f, + 1.4220f, 0.3375f, 0.0438f, 1.8698f, -0.2635f, + -2.0799f, -0.6313f, 0.4090f, -1.1458f, 0.0784f, + -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f, + 0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f, + -0.4558f, -1.6105f, 0.2979f, 1.1537f, -1.5604f, + 1.2779f, -1.2514f, 0.6056f, 0.5763f, -3.3558f, + 0.2836f, 0.6909f, -0.7631f, 2.4451f, -0.3500f, + 1.3289f, -0.6494f, 0.3478f, 1.0038f, -0.2937f, + 0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f, + 1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f, + -1.3306f, 1.5646f, 0.3338f, 0.7105f, 0.4683f, + -0.6179f, 0.0818f, -0.0488f, -0.9810f, -1.3632f, + 0.0929f, -1.7926f, -0.2921f, -0.4792f, 0.6756f, + -0.3413f, -0.2242f, -0.2111f, 0.6282f, 0.1667f, + -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f, + 0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f, + 1.6845f, -0.0901f, 0.6106f, 2.3603f, 1.3908f, + -0.7917f, -0.6734f, -0.1213f, -1.1116f, -0.7401f, + -0.7879f, 0.0606f, -2.3337f, -1.2603f, -1.7245f, + -0.3533f, -0.9421f, -0.1776f, 0.3992f, -1.7142f, + -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f, + -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f, + 0.3133f, -1.0941f, -0.3682f, -0.0163f, -0.0645f, + -0.8101f, 0.1415f, 0.0551f, 0.5873f, -0.5887f, + -1.4733f, -0.8565f, 0.7400f, -0.5033f, 0.0553f, + 0.9265f, -0.8652f, -0.0288f, -0.2209f, 0.0610f, + 0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f, + 0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f, + 1.3569f, 0.2983f, 0.4718f, -1.1936f, 0.7928f, + -0.8665f, 0.9468f, 1.1629f, 0.0616f, -1.3136f, + -0.2764f, 0.0277f, -0.1126f, 0.2342f, -0.5866f, + -1.8219f, 1.1079f, 0.5795f, -1.4249f}; + + std::vector position_ids = {0}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f, + 0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f, + -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f, + 0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f, + 0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, + 0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f, + 0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f, + 0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f, + 0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f, + 0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -0.8618f, + -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f, + 0.6923f, 1.1571f, 0.7572f, -1.1471f, -0.5302f, + -0.4391f, 0.5516f, 1.0461f, -0.4812f, -0.1443f, + -0.4862f, -0.6423f, 0.6740f, -0.4614f, 0.5475f, + 1.1495f, 0.2389f, 0.8582f, -0.0259f, -0.6099f, + -0.2230f, 1.0963f, -1.5704f, -0.4595f, 0.9507f, + 0.6696f, -0.7721f, -1.7415f, 1.2087f, -0.6387f, + -1.1052f, -0.5243f, -0.0400f, -0.4671f, 0.4909f, + -0.1931f, -0.1937f, -0.0447f, -0.3171f, 2.6839f, + -0.0076f, 1.5185f, 0.8465f, 0.3737f, 0.0242f, + -0.0703f, 1.1279f, 0.8862f, 1.2275f, -0.1786f, + -0.8767f, -1.8072f, -0.2630f, 0.9387f, -0.8021f, + 0.7813f, 0.5001f, -1.4202f, -0.3850f, 0.9263f, + -0.0443f, -0.2323f, 0.5480f, 1.5696f, 0.6193f, + -1.1346f, 1.7878f, -0.5160f, 0.1192f, -2.1572f, + 0.0460f, 1.1202f, -1.4812f, -0.9082f, 0.1728f, + -1.5132f, -0.4489f, 0.3370f, -0.1541f, -0.9266f, + 0.2416f, 0.9270f, -1.1146f, 1.8758f, -0.4312f, + 1.3714f, 1.2106f, -0.4272f, -0.8529f, 1.0328f, + 1.8441f, 1.7698f, -0.7620f, 0.2168f, 0.1322f, + -0.2802f, 0.1460f, 2.1002f, 0.8437f, -0.1534f, + 0.4321f, 0.8360f, 0.5955f, -1.5452f, -0.0491f, + -0.8794f, 0.2418f, -1.4203f, 0.3635f, 0.2362f, + 0.3672f, -0.1128f, -0.8664f, -0.6354f, -1.4409f, + -0.3413f, -0.2409f, -0.3188f, 1.1054f, 0.4265f, + 0.5867f, -1.3279f, 0.3201f, 0.0125f, 1.8157f, + 1.0745f, 0.7372f, -0.2429f, 0.7100f, -0.4299f, + -0.2304f, 0.1645f, 0.9489f, -0.1816f, -0.5968f, + 1.0394f, 0.0204f, 1.1786f, -0.3315f, -0.3997f, + -0.9304f, -1.4268f, -1.1526f, -0.1132f, 0.1490f, + 1.3967f, -1.4634f, -0.1412f, -0.6339f, -1.5995f, + -0.1366f, 0.7604f, 0.1514f, 0.0824f, -1.1830f, + -1.6572f, 2.0099f, -0.9108f, -0.2256f, 0.4527f, + -1.8254f, 0.6475f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -1.4979f, -1.1358f, 1.6320f, 0.2493f, + 0.8266f, 0.3424f, -0.4992f, 0.2964f, 0.7298f, + 1.8544f, 0.3516f, 0.0454f, 1.5415f, -0.2822f, + -2.0774f, 1.2323f, 0.3963f, -1.1503f, -0.4775f, + -1.9287f, -1.6164f, 0.3998f, 0.9020f, -0.0764f, + -1.8059f, -0.5762f, -1.4362f, -0.2706f, -1.0183f, + -0.4620f, 2.0891f, 0.1782f, 1.1591f, -0.8151f, + 1.3000f, -1.2464f, -0.5099f, 0.5098f, -3.3525f, + 0.4326f, 0.7414f, -0.7775f, -0.4271f, -0.3807f, + 1.3245f, 2.4936f, 0.3139f, 1.0095f, 0.2323f, + 0.8450f, -1.2244f, -0.4511f, 0.6266f, 0.9095f, + -1.7981f, 1.5241f, -0.4121f, 0.2341f, -0.4737f, + -1.3333f, -1.6150f, 0.4164f, 0.7100f, -0.2429f, + -0.5656f, 0.0863f, 0.0352f, -0.7227f, -1.3613f, + -0.0988f, -1.9114f, -0.3009f, 0.1435f, 0.7029f, + -0.3467f, 0.5092f, -0.0828f, 0.6253f, 0.7113f, + -1.2138f, 1.5964f, -0.8346f, -1.1515f, -0.7923f, + -0.8254f, -3.0038f, 2.4033f, -0.3398f, 0.0922f, + 1.7053f, 1.1114f, 0.7462f, 2.3660f, -0.8409f, + -0.6654f, -0.6530f, -0.7899f, -1.0957f, -0.7149f, + -0.1072f, -0.1967f, -2.3416f, -1.2609f, -1.6375f, + -0.3576f, 0.9413f, -0.5694f, 0.3954f, 0.1383f, + -0.7477f, -0.8689f, 1.8286f, 0.8510f, -1.4793f, + -0.1597f, 0.8541f, 0.2380f, 1.4392f, -0.5644f, + 0.3158f, -1.0686f, -0.1313f, -0.0181f, 0.2438f, + -0.8801f, 0.1413f, -0.3587f, 0.8002f, -0.5982f, + -1.4301f, -0.6620f, 0.7324f, -0.7250f, 0.0610f, + 0.9293f, -0.6902f, -0.0125f, -0.2089f, -0.1664f, + 0.5428f, 0.4245f, -0.7901f, 0.5665f, 0.9044f, + 0.1948f, -0.1723f, 1.2705f, 1.0303f, 1.2202f, + 1.3762f, -0.2959f, 0.7237f, -1.2077f, 0.7937f, + -0.6705f, 0.9287f, 1.0583f, 0.0496f, -1.3118f, + 0.5556f, 0.0459f, -0.1324f, -0.5513f, -0.7409f, + -1.8002f, 0.9892f, 0.3619f, -1.4522f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = false, pos ids shape = (batch_size, sequence_length) +TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 2; + int num_heads = 3; + int head_size = 6; + int max_sequence_length = 4; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -1.0574f, -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.8480f, 0.5266f, -1.2944f, -0.0243f, -0.2354f, -0.7087f, -0.9647f, -0.0991f, + -0.2994f, -0.0650f, -1.5720f, -1.3211f}; + + std::vector position_ids = {0, 1}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, 0.0043f, + 0.1411f, 0.1388f, 0.0065f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -0.8618f, -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f, + -0.4377f, 0.5370f, -1.2929f, -0.7267f, -0.2107f, -0.7115f, -0.4666f, -0.0261f, + -0.2965f, -0.8469f, -1.5749f, -1.3217f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc new file mode 100644 index 0000000000..326313fa09 --- /dev/null +++ b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc @@ -0,0 +1,282 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/providers/provider_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace std; + +namespace onnxruntime { +namespace test { + +TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) { + constexpr int64_t B = 2; + constexpr int64_t C = 16; + constexpr int64_t H = 2; + constexpr int64_t W = 2; + + std::vector dims_nhwc{B, H, W, C}; + std::vector input_data_nhwc = { + -0.768555f, 1.575195f, -0.698242f, 1.587891f, 0.371826f, -0.280029f, -1.328125f, 0.127197f, + -0.197144f, 0.982422f, -0.671387f, -1.925781f, 1.800781f, -0.020218f, -0.782227f, 1.291992f, + -0.935059f, 1.782227f, -0.674316f, -1.943359f, -0.218994f, 0.054138f, -1.539062f, -0.546387f, + -2.160156f, 1.195312f, 1.653320f, -0.674316f, 0.224731f, -0.093262f, 1.160156f, -0.389404f, + 1.748047f, 0.766113f, 0.234375f, 0.011177f, -0.055847f, -0.930664f, -0.490234f, -0.655762f, + -0.382568f, -0.554688f, 0.910645f, -0.227295f, 1.687500f, 0.028397f, -0.241699f, -0.480957f, + -0.355713f, -2.095703f, -0.443359f, -0.126221f, -0.815918f, 0.792969f, -0.450439f, -0.952148f, + -1.174805f, 0.242798f, 0.138550f, -0.237061f, -0.994141f, 0.346436f, 0.147705f, 0.125854f, + -0.517090f, 0.253906f, 0.400146f, -0.540039f, -0.788574f, 0.146606f, -0.409668f, 0.281982f, + 1.444336f, 0.044434f, -0.366699f, 2.250000f, -0.453613f, -0.652344f, 1.828125f, -0.244751f, + 0.307129f, -0.051361f, 0.106384f, 0.844727f, 1.648438f, -0.904785f, -0.353760f, 0.510742f, + 0.074829f, -0.311279f, 0.274902f, 1.594727f, 1.367188f, 0.098755f, 0.043304f, -0.207397f, + 0.068298f, -0.601074f, 0.083008f, 0.264893f, -0.659180f, -0.216797f, -0.086548f, -0.683594f, + -0.964844f, -2.591797f, -0.817383f, -0.461914f, -1.840820f, -0.712402f, -0.052094f, -0.583008f, + 1.114258f, 0.190308f, 1.087891f, 0.005146f, 1.041992f, 1.363281f, -0.273682f, -0.465576f, + -0.027618f, 1.345703f, 0.789551f, -0.015991f, 0.401611f, 0.726562f, 0.598633f, 0.133667f}; + + std::vector gamma_data = { + 0.241255f, 0.556660f, -0.835532f, 0.564596f, -1.338308f, -0.278924f, 0.357326f, -1.745484f, + 0.277184f, 0.101415f, -0.018637f, -0.526188f, -0.011698f, -2.349411f, 0.206578f, 0.357679f}; + + std::vector beta_data = { + -1.194839f, 0.209146f, -0.677225f, -0.547338f, 1.275685f, -1.099577f, 0.470916f, 0.293907f, + -1.094209f, 2.350204f, -1.633769f, 0.248753f, -0.180166f, 0.365134f, -0.555731f, 1.843083f}; + + std::vector skip_data_nhwc = { + 0.892578f, -0.471924f, -0.423096f, 1.277344f, 0.257080f, -1.366211f, 1.552734f, 0.441406f, + -0.033142f, -0.059418f, 1.536133f, -0.225464f, 1.472656f, 0.591309f, -0.386230f, -2.197266f, + 0.089600f, -0.256592f, -1.873047f, 0.916992f, 0.392090f, 0.015526f, -0.949219f, 0.566895f, + -0.220459f, 1.262695f, -0.437744f, -2.283203f, -0.264893f, -0.660156f, 2.353516f, 1.992188f, + 0.865723f, -0.854004f, -1.014648f, 0.899414f, -1.041016f, 1.378906f, -0.075073f, -2.541016f, + -0.883789f, -0.428711f, 0.981934f, -0.072754f, 2.214844f, 0.658203f, 0.170166f, -1.727539f, + -0.672363f, -1.373047f, 0.318115f, 0.422363f, 0.260742f, -0.547852f, 0.545898f, -0.155762f, + 0.679688f, 2.861328f, -0.300781f, -0.504883f, 1.548828f, 0.353760f, -0.387695f, -1.595703f, + -0.170166f, -0.002897f, 0.273193f, -0.383545f, -1.082031f, -0.894043f, -1.048828f, -0.044708f, + 0.049286f, 0.220215f, 0.272705f, -0.853027f, -0.489258f, 0.513672f, 0.977051f, 0.310547f, + -0.577148f, -0.479004f, 0.838867f, 0.872559f, -0.510254f, 0.101807f, -0.299805f, -1.179688f, + -1.555664f, 0.668457f, 0.939453f, 0.118103f, -0.376709f, 0.735352f, -0.214233f, -1.987305f, + -0.931152f, 1.268555f, 1.427734f, -0.757812f, -1.324219f, 0.375488f, 1.364258f, -1.708008f, + 0.976562f, -0.037659f, -1.779297f, -0.196655f, 1.636719f, 0.690430f, 0.941895f, -1.882812f, + 0.431641f, 0.203857f, 1.306641f, -0.126343f, 1.408203f, 1.188477f, 0.432861f, -2.296875f, + -0.475342f, 1.517578f, -0.824219f, 1.288086f, -0.028244f, 1.918945f, 0.352295f, 0.693359f}; + + std::vector bias_data = { + -0.537598f, 0.500488f, -0.252441f, -0.460693f, -1.640625f, -1.298828f, 0.331787f, -1.588867f, + 1.000977f, 1.458984f, 0.702637f, 0.147827f, 1.143555f, 0.533691f, -0.072510f, 0.511230f}; + + std::vector norm_data_nhwc = { + -1.213867f, 0.856445f, -0.119141f, 0.386475f, 0.714355f, -0.804688f, + 1.048828f, -0.426270f, -1.091797f, 2.435547f, -1.641602f, 0.989746f, + -0.200928f, 0.267334f, -0.800781f, 1.577148f, -1.357422f, 1.000977f, + 0.613281f, -0.963867f, 1.179688f, -1.169922f, 0.308350f, 0.304199f, + -1.396484f, 2.513672f, -1.644531f, 1.206055f, -0.180664f, 1.896484f, + -0.294678f, 2.046875f, -0.844238f, 0.448486f, -0.294189f, -0.291504f, + 2.480469f, -1.250977f, 0.833008f, 4.593750f, -1.238281f, 2.335938f, + -1.651367f, 0.491943f, -0.204834f, 0.125610f, -0.682129f, 1.333984f, + -1.384766f, -0.708008f, -0.630859f, -0.504883f, 1.924805f, -1.208008f, + 1.013672f, 1.809570f, -1.128906f, 2.546875f, -1.631836f, 0.610840f, + -0.184326f, 0.110046f, -0.700195f, 1.471680f, -1.511719f, 0.492188f, + -0.847168f, -1.373047f, 2.837891f, -0.998047f, 0.521484f, 0.262207f, + -0.810547f, 2.400391f, -1.628906f, 0.049896f, -0.174927f, 1.076172f, + -0.252197f, 1.784180f, -1.418945f, 0.090820f, -1.056641f, 0.002945f, + 0.627441f, -0.989746f, 0.679199f, 1.130859f, -1.371094f, 2.408203f, + -1.645508f, -0.062988f, -0.192017f, -0.655762f, -0.718262f, 1.170898f, + -1.550781f, 0.706055f, -1.492188f, -1.148438f, 2.921875f, -1.136719f, + 1.058594f, 2.781250f, -1.089844f, 2.201172f, -1.597656f, 0.785645f, + -0.181396f, 0.868164f, -0.552246f, 1.097656f, -1.015625f, 0.565430f, + -2.173828f, -0.955078f, -0.336426f, -1.503906f, 0.838867f, 3.136719f, + -1.186523f, 2.580078f, -1.629883f, 0.094604f, -0.186523f, -3.884766f, + -0.542480f, 1.990234f}; + + std::vector add_out_data_nhwc = { + -0.414062f, 1.604492f, -1.374023f, 2.404297f, -1.011719f, -2.945312f, 0.556641f, -1.020508f, + 0.770508f, 2.382812f, 1.567383f, -2.003906f, 4.417969f, 1.105469f, -1.240234f, -0.394531f, + -1.382812f, 2.027344f, -2.800781f, -1.487305f, -1.466797f, -1.229492f, -2.156250f, -1.568359f, + -1.379883f, 3.917969f, 1.917969f, -2.808594f, 1.103516f, -0.219727f, 3.441406f, 2.113281f, + 2.076172f, 0.412598f, -1.033203f, 0.449951f, -2.738281f, -0.851562f, -0.233521f, -4.785156f, + -0.265625f, 0.475586f, 2.595703f, -0.152222f, 5.046875f, 1.220703f, -0.144043f, -1.697266f, + -1.566406f, -2.968750f, -0.377686f, -0.164551f, -2.195312f, -1.053711f, 0.427246f, -2.697266f, + 0.505859f, 4.562500f, 0.540527f, -0.594238f, 1.698242f, 1.233398f, -0.312500f, -0.958496f, + -1.224609f, 0.751465f, 0.420898f, -1.384766f, -3.511719f, -2.046875f, -1.126953f, -1.351562f, + 2.494141f, 1.724609f, 0.608398f, 1.544922f, 0.200684f, 0.395020f, 2.732422f, 0.577148f, + -0.807617f, -0.029785f, 0.692871f, 1.256836f, -0.502441f, -2.101562f, -0.321777f, -2.257812f, + -0.479492f, 1.816406f, 1.916992f, 1.860352f, 2.134766f, 1.367188f, -0.243408f, -1.683594f, + -1.400391f, 1.167969f, 1.257812f, -0.953613f, -3.625000f, -1.140625f, 1.609375f, -3.980469f, + 1.012695f, -1.170898f, -1.894531f, -0.510742f, 0.939453f, 0.511719f, 0.817383f, -1.955078f, + 1.007812f, 0.894531f, 2.142578f, -0.582031f, 0.809570f, 1.252930f, 0.490967f, -4.351562f, + 0.497803f, 4.320312f, 0.667969f, 1.419922f, 1.516602f, 3.179688f, 0.878906f, 1.337891f}; + + int min_cuda_architecture = 530; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + + std::array channels_last_values = {-1, 1}; + + for (const int channels_last : channels_last_values) { + if (enable_cuda) { + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + + // Don't run the test if no providers are supported + if (execution_providers.empty()) { + continue; + } + + OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 4); + test.AddAttribute("activation", 0); + + // We interpret channels_last==-1 as the attribute not being provided + if (channels_last != -1) { + test.AddAttribute("channels_last", channels_last); + } + + test.AddInput("X", dims_nhwc, ToFloat16(input_data_nhwc)); + test.AddInput("gamma", {C}, gamma_data); + test.AddInput("beta", {C}, beta_data); + test.AddInput("skip", dims_nhwc, ToFloat16(skip_data_nhwc)); + test.AddInput("bias", {C}, ToFloat16(bias_data)); + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error); + test.AddOutput("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } +} + +TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { + constexpr int64_t B = 1; + constexpr int64_t C = 64; + constexpr int64_t H = 1; + constexpr int64_t W = 1; + + std::vector dims_nhwc{B, H, W, C}; + std::vector input_data_nhwc = { + 0.588867f, 0.896484f, -0.213623f, 0.803223f, 0.659180f, -0.216187f, 1.197266f, -0.486084f, + -0.718750f, 0.332031f, -0.364746f, -0.831543f, -0.031219f, -1.059570f, 0.161621f, 1.519531f, + 0.169312f, 1.048828f, 1.330078f, 0.450195f, -2.867188f, -1.456055f, 0.708496f, -1.120117f, + -1.208984f, -1.199219f, -1.505859f, -0.549316f, 0.505371f, 0.723145f, -0.359131f, -0.250977f, + -0.879883f, -0.305664f, 0.709473f, 0.815430f, 0.617676f, -0.638672f, 0.066772f, -2.330078f, + -1.316406f, 1.744141f, 1.122070f, -0.633789f, -1.802734f, -0.825684f, 0.622559f, -0.481689f, + -1.364258f, -0.536621f, -0.464111f, 0.247437f, -0.213989f, 0.384521f, 0.556641f, -0.303711f, + -0.160034f, 0.882324f, -0.212036f, -0.796387f, 0.153076f, -1.311523f, 2.212891f, 0.685059f}; + + std::vector gamma_data = { + 0.789682f, 0.869051f, -0.010169f, -0.021685f, 0.506611f, 1.267444f, -0.312695f, 0.877844f, + 0.598637f, 0.598314f, -1.721544f, -0.593328f, 0.986705f, -0.419391f, -0.852584f, -0.572351f, + 0.912797f, -0.586863f, 0.477761f, -0.484418f, -0.193835f, 0.347757f, 0.327637f, -1.100304f, + 1.233108f, -0.272569f, -0.688656f, 0.687245f, 0.398386f, 0.888089f, -0.792587f, -0.769029f, + -0.427778f, 0.100768f, -2.187060f, 1.279301f, 1.109054f, 0.375992f, 1.514775f, 1.271436f, + 0.822896f, -0.476750f, 0.475507f, -1.011297f, 1.177197f, 1.586540f, -1.059944f, -0.145351f, + 0.841555f, -2.014113f, -0.230498f, 0.302128f, -0.180508f, 0.980534f, -0.126871f, 0.203151f, + -0.754841f, 0.420570f, -1.085798f, 1.335042f, -0.674930f, 2.453507f, 2.139259f, 1.087436f}; + + std::vector beta_data = { + -0.064518f, -0.262683f, 0.827528f, -0.960938f, 1.062519f, 2.417941f, 0.212789f, -1.638430f, + 1.875453f, -0.883058f, -0.006704f, 0.424894f, -0.869972f, 0.727008f, 0.879303f, -3.024141f, + -2.610873f, 1.269641f, 0.883006f, 0.804167f, -1.510324f, 2.258091f, -0.006750f, -1.553668f, + -1.659453f, 0.579603f, 0.652358f, 0.007077f, 0.099180f, 0.418658f, -0.273778f, -1.036199f, + -1.128691f, -0.296022f, -0.224056f, 1.476306f, 0.577624f, -0.372049f, -0.581659f, -1.841807f, + -0.361721f, 0.051160f, -0.749332f, -2.634807f, 0.562719f, -0.738667f, 0.024864f, -1.135937f, + -1.368144f, -1.458886f, -0.946683f, 1.953936f, -1.198661f, 0.166648f, 0.447206f, -0.458140f, + -0.553395f, 0.112900f, 0.255989f, -0.184551f, 1.254163f, -0.260479f, -1.232429f, 1.902575f}; + + std::vector skip_data = { + 0.952148f, 1.342773f, -0.172974f, -0.395264f, 1.119141f, 0.330566f, + 0.281494f, 0.472900f, -0.692871f, -0.634766f, 0.013504f, -1.866211f, + -0.428223f, 0.669922f, -0.323486f, 0.713867f, -0.350586f, 0.659180f, + -0.288574f, 0.324219f, -0.300781f, -0.789551f, -0.216431f, -0.221436f, + -0.086670f, 0.366211f, -0.643555f, -0.977051f, 0.001021f, 0.415527f, + -0.271729f, 0.836426f, 0.035370f, -0.806152f, 0.936035f, -0.021332f, + -1.095703f, 0.971680f, 1.648438f, 0.840820f, 0.837402f, 0.607910f, + -1.894531f, 0.666016f, -0.171143f, 1.625977f, -0.620117f, -0.039581f, + 1.702148f, -2.410156f, 1.565430f, -0.756348f, 1.446289f, 0.583496f, + -0.497559f, -0.271729f, -0.956055f, -1.642578f, 0.833496f, -1.136719f, + 1.248047f, -2.515625f, 0.080383f, 0.376221f}; + + std::vector norm_data_nhwc = { + 0.494873f, 1.017578f, 0.841797f, -0.949219f, 1.552734f, 1.333984f, 0.012703f, -2.511719f, + 1.424805f, -0.818359f, -0.128418f, 1.462891f, -0.882812f, 0.709961f, 0.693848f, -4.210938f, + -2.505859f, 0.513184f, 1.300781f, 0.460938f, -1.172852f, 1.851562f, 0.167969f, -0.885254f, + -2.535156f, 0.656738f, 1.683594f, -0.627441f, 0.478271f, 1.782227f, -0.196777f, -1.824219f, + -0.791016f, -0.398682f, -3.197266f, 2.275391f, 0.052704f, -0.286865f, 1.567383f, -3.552734f, + -0.646973f, -0.927734f, -1.032227f, -2.722656f, -1.337891f, 0.432129f, -0.040253f, -1.080078f, + -1.118164f, 3.123047f, -1.153320f, 1.843750f, -1.378906f, 0.941406f, 0.437256f, -0.542969f, + -0.218872f, 0.006115f, -0.265869f, -1.356445f, 0.649902f, -4.882812f, 1.696289f, 2.679688f}; + + std::vector add_out_data_nhwc = { + 1.541016f, 2.238281f, -0.386719f, 0.407959f, 1.778320f, 0.114380f, + 1.478516f, -0.013184f, -1.412109f, -0.302734f, -0.351318f, -2.697266f, + -0.459473f, -0.389648f, -0.161865f, 2.234375f, -0.181274f, 1.708008f, + 1.041016f, 0.774414f, -3.167969f, -2.246094f, 0.492188f, -1.341797f, + -1.295898f, -0.833008f, -2.148438f, -1.526367f, 0.506348f, 1.138672f, + -0.630859f, 0.585449f, -0.844727f, -1.111328f, 1.645508f, 0.793945f, + -0.478027f, 0.333008f, 1.714844f, -1.489258f, -0.479004f, 2.351562f, + -0.772461f, 0.032227f, -1.973633f, 0.800293f, 0.002441f, -0.521484f, + 0.337891f, -2.947266f, 1.101562f, -0.508789f, 1.232422f, 0.967773f, + 0.059082f, -0.575195f, -1.116211f, -0.760254f, 0.621582f, -1.933594f, + 1.401367f, -3.828125f, 2.292969f, 1.061523f}; + + int min_cuda_architecture = 530; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + + std::array has_add_out_values = {true, false}; + std::array skip_dims = {2, 4}; + + constexpr int channels_last = 1; + for (const int skip_dim : skip_dims) { + for (const bool has_add_out : has_add_out_values) { + if (enable_cuda) { + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + + // Don't run the test if no providers are supported + if (execution_providers.empty()) { + continue; + } + + OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 8); + test.AddAttribute("activation", 0); + test.AddAttribute("channels_last", channels_last); + + test.AddInput("X", dims_nhwc, ToFloat16(input_data_nhwc)); + test.AddInput("gamma", {C}, gamma_data); + test.AddInput("beta", {C}, beta_data); + if (skip_dim == 2) { + test.AddInput("skip", {B, C}, ToFloat16(skip_data)); + } else { + test.AddInput("skip", {B, 1, 1, C}, ToFloat16(skip_data)); + } + // no bias + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error); + + if (has_add_out) { + test.AddOutput("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error); + } + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index a41a1dd4ec..e8f2310666 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -11,14 +11,15 @@ namespace onnxruntime { namespace test { constexpr float epsilon_ = 1e-12f; -static void RunTest( +static void RunOneTest( + bool strict, const std::vector& input_data, const std::vector& skip_data, const std::vector& gamma_data, const std::vector& beta_data, const std::vector& bias_data, const std::vector& output_data, - const std::vector& skip_input_bias_add_output_data, + const std::vector& sum_output_data, float epsilon, int batch_size, int sequence_length, @@ -27,7 +28,6 @@ static void RunTest( bool no_beta = false, bool simplified = false, bool use_token_count = false, - bool strict = false, bool broadcast_skip = false, bool no_batch_size = false) { // Input and output shapes @@ -82,14 +82,14 @@ static void RunTest( test.AddOutput("output", output_dims, output_data); - if (skip_input_bias_add_output_data.size() != 0) { + if (sum_output_data.size() != 0) { // The second and third outputs are reserved for something else test.AddOptionalOutputEdge(); test.AddOptionalOutputEdge(); test.AddOutput("skip_input_bias_add_output", output_dims, - skip_input_bias_add_output_data); + sum_output_data); } if (cpu_ep != nullptr) { @@ -117,14 +117,19 @@ static void RunTest( test.AddOutput("output", output_dims, ToFloat16(output_data)); - if (skip_input_bias_add_output_data.size() != 0) { + // Use larger threshold for fp16 + if (use_float16) { + test.SetOutputAbsErr("output", 0.01f); + } + + if (sum_output_data.size() != 0) { // The second and third outputs are reserved for something else test.AddOptionalOutputEdge(); test.AddOptionalOutputEdge(); test.AddOutput("skip_input_bias_add_output", output_dims, - ToFloat16(skip_input_bias_add_output_data)); + ToFloat16(sum_output_data)); } if (dml_ep != nullptr) { @@ -151,6 +156,36 @@ static void RunTest( } } +static void RunTest( + const std::vector& input_data, + const std::vector& skip_data, + const std::vector& gamma_data, + const std::vector& beta_data, + const std::vector& bias_data, + const std::vector& output_data, + const std::vector& sum_output_data, + float epsilon, + int batch_size, + int sequence_length, + int hidden_size, + bool use_float16 = false, + bool no_beta = false, + bool simplified = false, + bool use_token_count = false, + bool broadcast_skip = false, + bool no_batch_size = false) { + RunOneTest(false, input_data, skip_data, gamma_data, beta_data, bias_data, output_data, sum_output_data, + epsilon, batch_size, sequence_length, hidden_size, use_float16, no_beta, simplified, + use_token_count, broadcast_skip, no_batch_size); + + // strict mode does not support skip broadcasting. + if (!broadcast_skip) { + RunOneTest(true, input_data, skip_data, gamma_data, beta_data, bias_data, output_data, sum_output_data, + epsilon, batch_size, sequence_length, hidden_size, use_float16, no_beta, simplified, + use_token_count, broadcast_skip, no_batch_size); + } +} + TEST(SkipLayerNormTest, SkipLayerNormNullInput) { int batch_size = 1; int sequence_length = 0; @@ -359,8 +394,7 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch1_Float16_vec) { true /*use_float16*/, false /*no_beta*/, false /*simplified*/, - false /*use_token_count*/, - true /*strict*/); + false /*use_token_count*/); } TEST(SkipLayerNormTest, SkipLayerNormBatch1_NoBeta) { @@ -648,8 +682,7 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch1_Float16_vec_token_count) { true /*use_float16*/, false /*no_beta*/, false /*simplified*/, - true /*use_token_count*/, - true /*strict*/); + true /*use_token_count*/); } TEST(SkipLayerNormTest, SkipLayerNormBatch2_TokenCount) { @@ -774,13 +807,12 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) { batch_size, sequence_length, hidden_size, - false, - false, - false, - false, - false, - false, - true); + false, // use_float16 + false, // no_beta + false, // simplified + false, // use_token_count + true, // broadcast_skip + true); // no_batch_size } TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_Batch_Size_1) { @@ -821,13 +853,12 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_Batch_Size_1) { batch_size, sequence_length, hidden_size, - false, - false, - false, - false, - false, - true, - false); + false, // use_float16 + false, // no_beta + false, // simplified + false, // use_token_count + true, // broadcast_skip + false); // no_batch_size } #endif diff --git a/onnxruntime/test/framework/float_8_test.cc b/onnxruntime/test/framework/float_8_test.cc new file mode 100644 index 0000000000..520d44f338 --- /dev/null +++ b/onnxruntime/test/framework/float_8_test.cc @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(DISABLE_FLOAT8_TYPES) + +#include + +#include "core/framework/float8.h" +#include "test/framework/test_utils.h" +#include "test/util/include/test/capturing_sink.h" +#include "test/util/include/test/test_environment.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +TEST(Float8_Tests, CastE4M3FN) { + std::vector> cases{ + std::pair(0.00439453125, 0.00390625), + std::pair(0.005859375, 0.005859375), + std::pair(0.005759375, 0.005859375), + std::pair(0.0046875, 0.00390625), + std::pair(0.001953125, 0.001953125), + std::pair(0.0029296875, 0.00390625), + std::pair(0.002053125, 0.001953125), + std::pair(0.00234375, 0.001953125), + std::pair(0.0087890625, 0.0078125), + std::pair(0.001171875, 0.001953125), + std::pair(1.8131605, 1.875)}; + for (auto it : cases) { + auto f8 = onnxruntime::Float8E4M3FN(it.first); + auto f8_32 = f8.ToFloat(); + EXPECT_EQ(it.second, f8_32); + } +} + +union float_bits { + uint32_t bits; + float val; +}; + +TEST(Float8_Tests, NanE4M3FN) { + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800000}).val).val, static_cast(0x7E)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800000}).val).val, static_cast(0xFE)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800000}).val, false).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800000}).val, false).val, static_cast(0xFF)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800001}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800001}).val).val, static_cast(0xFF)); + // 0x7FC00000 is the value used by numpy. + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7FC00000}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFFC00000}).val).val, static_cast(0xFF)); +} + +TEST(Float8_Tests, NanE4M3FNUZ) { + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800000}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800000}).val).val, static_cast(0xFF)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800000}).val, false).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800000}).val, false).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800001}).val).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800001}).val).val, static_cast(0x80)); + // 0x7FC00000 is the value used by numpy. + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7FC00000}).val).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFFC00000}).val).val, static_cast(0x80)); +} + +TEST(Float8_Tests, NanE5M2) { + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800000}).val).val, static_cast(0x7B)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800000}).val).val, static_cast(0xFB)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800000}).val, false).val, static_cast(0x7C)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800000}).val, false).val, static_cast(0xFC)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800001}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800001}).val).val, static_cast(0xFF)); + // 0x7FC00000 is the value used by numpy. + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7FC00000}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFFC00000}).val).val, static_cast(0xFF)); +} + +TEST(Float8_Tests, NanE5M2FNUZ) { + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800000}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800000}).val).val, static_cast(0xFF)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800000}).val, false).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800000}).val, false).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800001}).val).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800001}).val).val, static_cast(0x80)); + // 0x7FC00000 is the value used by numpy. + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7FC00000}).val).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFFC00000}).val).val, static_cast(0x80)); +} + +} // namespace test +} // namespace onnxruntime + +#endif // DISABLE_FLOAT8_TYPES diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index f8da4e8959..e2cb82e47f 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -4,6 +4,7 @@ #include "core/flatbuffers/schema/ort.fbs.h" #include "core/framework/data_types.h" #include "core/framework/tensorprotoutils.h" +#include "core/framework/TensorSeq.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" #include "core/session/onnxruntime_cxx_api.h" @@ -556,6 +557,41 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBufferNoCopyInitializersUseBuffer) RunOrtModel(test_info); } +// regression test for 2 issues covered by PR #17000 (internally reported issue). +// 1) allocation planner broke in minimal build when subgraph had no nodes. +// 2) usage of a sequence data type caused an exception due to IsSparseTensor() throwing +// instead of allowing the calling code to have #ifdef'd code to handle when IsSparseTensor +// returned true and sparse tensors were disabled. +TEST(OrtModelOnlyTests, GithubIssue17000) { + // need to run the model to + auto model_uri = ORT_TSTR("testdata/ort_github_issue_17000.ort"); + + auto allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + + OrtValue item0, item1; + CreateMLValue(allocator, {1}, {1.f}, &item0); + CreateMLValue(allocator, {2}, {2.f, 3.f}, &item1); + + auto elem_type = DataTypeImpl::GetType(); + auto tensor_seq = std::make_unique(elem_type); + tensor_seq->SetElements({item0, item1}); + + auto mltype = DataTypeImpl::GetType(); + OrtValue value(tensor_seq.release(), mltype, mltype->GetDeleteFunc()); + + OrtModelTestInfo test_info; + test_info.model_filename = model_uri; + test_info.inputs.insert(std::make_pair("seq_in", value)); + test_info.output_names = {"still_has_elements"}; + test_info.output_verifier = [](const std::vector& fetches) { + const auto& output = fetches[0].Get(); + ASSERT_EQ(output.Shape().Size(), 1); + ASSERT_EQ(output.Data()[0], true); // removed one item from seq so should still have elements + }; + + RunOrtModel(test_info); +} + #if !defined(DISABLE_ML_OPS) // test that we can deserialize and run a previously saved ORT format model // for a model with sequence and map outputs diff --git a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc index 0101677428..a03d0da253 100644 --- a/onnxruntime/test/optimizer/compute_optimizer_test.cc +++ b/onnxruntime/test/optimizer/compute_optimizer_test.cc @@ -638,7 +638,8 @@ TEST(ComputeOptimizerTests, GatherMatMul_ScalarSlicingOnSecondLastDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); @@ -737,7 +738,8 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnSecondLastDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); @@ -826,6 +828,345 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnSecondLastDim) { } } +/* +Test graph includes multiple equivalent subgraphs as below. + graph input [2, 32, 256] (float) + | + LayerNormalization[axis=-1 (as example)] + | + [2, 32, 256] + | + | 0 (scalar) + | / + Gather[axis=1] + | + Identity + | + graph output [2, 256] (float) + +Add an Identity node because currently, we don't allow Gather generates graph output. +*/ +TEST(ComputeOptimizerTests, GatherLayerNormalization) { + std::vector> test_config_pairs{ + // { + // is_scalar_slice, + // ln_axis_before_propagation, + // expected_ln_axis_after_propagation, + // expected to propagate + // } + {true, 0, 0, false}, + {true, 1, 1, false}, + {true, 2, 1, true}, + {true, -3, -3, false}, + {true, -2, -2, false}, + {true, -1, 1, true}, + {false, 0, 0, false}, + {false, 1, 1, false}, + {false, 2, 2, true}, + {false, -3, -3, false}, + {false, -2, -2, false}, + {false, -1, -1, true}, + }; + + constexpr static int64_t gather_axis = 1; + constexpr static int64_t slice_data_value = 0; + + for (auto p : test_config_pairs) { + bool is_scalar_slice = std::get<0>(p); + int64_t ln_axis_before = std::get<1>(p); + int64_t ln_axis_after = std::get<2>(p); + bool expected_to_propagate = std::get<3>(p); + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + + InlinedVector indices; + auto pre_graph_checker = [&indices](Graph& graph) -> Status { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); + TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Gather") { + TEST_RETURN_IF_NOT(indices.empty()); + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(node.InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, + indices, require_constant)); + } + } + return Status::OK(); + }; + + auto post_graph_checker = [is_scalar_slice, ln_axis_after, + &indices, expected_to_propagate](Graph& graph) { + auto op_count_post = CountOpsInGraph(graph); + + TEST_RETURN_IF_NOT(op_count_post.size() == 3U); + TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + const auto& input_defs = node.InputDefs(); + + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + if (expected_to_propagate) { + TEST_RETURN_IF_NOT(producer_node != nullptr); + TEST_RETURN_IF_NOT(producer_node->OpType() == "Gather"); + + InlinedVector values; + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, + values, require_constant)); + for (size_t i = 0; i < values.size(); i++) { + TEST_RETURN_IF_NOT(values[i] == indices[i]); + } + + const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(slice_out_shape != nullptr); + + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == ln_axis_after); + + if (is_scalar_slice) { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 256); + } else { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 3); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 1); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && + slice_out_shape->dim(2).dim_value() == 256); + } + + } else { + TEST_RETURN_IF_NOT(producer_node == nullptr); + } + } + } + + return Status::OK(); + }; + + auto build_test_case = [is_scalar_slice, ln_axis_before](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 32, 256}}); + auto* input2_arg = builder.MakeInput({{256}}); + auto* input3_arg = builder.MakeInput({{256}}); + auto* ln_out = builder.MakeIntermediate(); + builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) + .AddAttribute("axis", ln_axis_before); + + std::vector slice_inputs; + NodeArg* indices_initializer = nullptr; + + if (is_scalar_slice) { + indices_initializer = builder.MakeScalarInitializer(slice_data_value); + } else { + indices_initializer = builder.MakeInitializer({1}, {slice_data_value}); + } + + slice_inputs = {ln_out, indices_initializer}; + + auto* gather_out = builder.MakeIntermediate(); + builder.AddNode("Gather", slice_inputs, + {gather_out}) + .AddAttribute("axis", gather_axis); + + auto* identity_out = builder.MakeOutput(); + builder.AddNode("Identity", {gather_out}, {identity_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger, std::move(transformer), + TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); + } +} + +/* +Test graph includes multiple equivalent subgraphs as below. + graph input [2, 4, 32, 256] (float) + | + Softmax[axis=3 (as example)] + | + [2, 4, 32, 256] + | + | 0 (scalar) + | / + Gather[axis=1] + | + Identity + | + graph output [2, 32, 256] (float) + +Add an Identity node because currently, we don't allow Gather generates graph output. +*/ +TEST(ComputeOptimizerTests, GatherSoftmax) { + std::vector> test_config_pairs{ + // {is_scalar_slice, softmax_axis_before_propagation, + // expected_softmax_axis_after_propagation, expected to propagate} + {true, 0, 0, false}, + {true, 1, 1, false}, + {true, 2, 1, true}, + {true, 3, 2, true}, + {true, -4, -4, false}, + {true, -3, -3, false}, + {true, -2, 1, true}, + {true, -1, 2, true}, + {false, 0, 0, false}, + {false, 1, 1, false}, + {false, 2, 2, true}, + {false, 3, 3, true}, + {false, -4, -4, false}, + {false, -3, -3, false}, + {false, -2, -2, true}, + {false, -1, -1, true}, + }; + + constexpr static int64_t gather_axis = 1; + constexpr static int64_t slice_data_value = 0; + + for (auto p : test_config_pairs) { + bool is_scalar_slice = std::get<0>(p); + int64_t softmax_axis_before = std::get<1>(p); + int64_t softmax_axis_after = std::get<2>(p); + bool expected_to_propagate = std::get<3>(p); + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + + InlinedVector indices; + auto pre_graph_checker = [&indices](Graph& graph) -> Status { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); + TEST_RETURN_IF_NOT(op_count_pre["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Gather") { + TEST_RETURN_IF_NOT(indices.empty()); + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(node.InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, + indices, require_constant)); + } + } + return Status::OK(); + }; + + auto post_graph_checker = [is_scalar_slice, softmax_axis_after, + &indices, expected_to_propagate](Graph& graph) { + auto op_count_post = CountOpsInGraph(graph); + + TEST_RETURN_IF_NOT(op_count_post.size() == 3U); + TEST_RETURN_IF_NOT(op_count_post["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Softmax") { + const auto& input_defs = node.InputDefs(); + + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + if (expected_to_propagate) { + TEST_RETURN_IF_NOT(producer_node != nullptr); + TEST_RETURN_IF_NOT(producer_node->OpType() == "Gather"); + + InlinedVector values; + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values, + require_constant)); + for (size_t i = 0; i < values.size(); i++) { + TEST_RETURN_IF_NOT(values[i] == indices[i]); + } + + const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(slice_out_shape != nullptr); + + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == softmax_axis_after); + + if (is_scalar_slice) { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 3); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 32); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && + slice_out_shape->dim(2).dim_value() == 256); + } else { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 4); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 1); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && + slice_out_shape->dim(2).dim_value() == 32); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(3)) && + slice_out_shape->dim(3).dim_value() == 256); + } + + } else { + TEST_RETURN_IF_NOT(producer_node == nullptr); + } + } + } + + return Status::OK(); + }; + + auto build_test_case = [is_scalar_slice, softmax_axis_before](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 4, 32, 256}}); + auto* softmax_out = builder.MakeIntermediate(); + builder.AddNode("Softmax", {input1_arg}, {softmax_out}) + .AddAttribute("axis", softmax_axis_before); + + std::vector slice_inputs; + + NodeArg* indices_initializer = nullptr; + + if (is_scalar_slice) { + indices_initializer = builder.MakeScalarInitializer(slice_data_value); + } else { + indices_initializer = builder.MakeInitializer({1}, {slice_data_value}); + } + + slice_inputs = {softmax_out, indices_initializer}; + + auto* gather_out = builder.MakeIntermediate(); + builder.AddNode("Gather", slice_inputs, + {gather_out}) + .AddAttribute("axis", gather_axis); + + auto* identity_out = builder.MakeOutput(); + builder.AddNode("Identity", {gather_out}, {identity_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger, std::move(transformer), + TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); + } +} + TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnBatchDim) { const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_reshape_scalar_batch_dim.onnx"; @@ -835,7 +1176,8 @@ TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnBatchDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); @@ -928,7 +1270,8 @@ TEST(ComputeOptimizerTests, GatherReshape_SlicingOnBatchDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index 6fd1f6081c..85ccb8f175 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -202,19 +202,19 @@ struct TensorCheck { // NOTE: Check isnan first to work around MSVC linker bug when /LTCG:incremental is specified. // If the isinf check is first the isnan check and branch gets omitted if (std::isnan(cur_expected[i])) { - ASSERT_TRUE(std::isnan(cur_actual[i])) << "Expected NaN. i:" << i; + EXPECT_TRUE(std::isnan(cur_actual[i])) << "Expected NaN. i:" << i; } else if (std::isinf(cur_expected[i])) { // Test infinity for equality - ASSERT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i; + EXPECT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i; } else { if (!has_abs_err && !has_rel_err) { // the default for existing tests - ASSERT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i; + EXPECT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i; } else { if (has_abs_err) { - ASSERT_NEAR(cur_expected[i], cur_actual[i], *(params.absolute_error)) << "i:" << i; + EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.absolute_error)) << "i:" << i; } if (has_rel_err) { - ASSERT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * std::abs(cur_expected[i])) + EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * std::abs(cur_expected[i])) << "i:" << i; } } @@ -256,20 +256,20 @@ void InternalNumericalCheck(const Tensor& expected, // NOTE: Check isnan first to work around MSVC linker bug when /LTCG:incremental is specified. // If the isinf check is first the isnan check and branch gets omitted if (std::isnan(cur_expected[i])) { - ASSERT_TRUE(std::isnan(cur_actual[i])) << "Expected NaN. i:" << i; + EXPECT_TRUE(std::isnan(cur_actual[i])) << "Expected NaN. i:" << i; } else if (std::isinf(cur_expected[i])) { // Test infinity for equality - ASSERT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i; + EXPECT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i; } else { if (!has_abs_err && !has_rel_err) { // the default for existing tests - ASSERT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i; + EXPECT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i; } else { if (has_abs_err) { - ASSERT_NEAR(cur_expected[i], cur_actual[i], *(params.absolute_error)) + EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.absolute_error)) << "i:" << i; } if (has_rel_err) { - ASSERT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * std::abs(cur_expected[i])) + EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * std::abs(cur_expected[i])) << "i:" << i; } } diff --git a/onnxruntime/test/providers/cpu/math/sign_test.cc b/onnxruntime/test/providers/cpu/math/sign_test.cc index 12844068c4..15b3f40faa 100644 --- a/onnxruntime/test/providers/cpu/math/sign_test.cc +++ b/onnxruntime/test/providers/cpu/math/sign_test.cc @@ -113,7 +113,7 @@ TestImpl(ForwardIter first, ForwardIter last, OutputIter out) { TEST(MathOpTest, Sign_uint64) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; @@ -129,7 +129,7 @@ TEST(MathOpTest, Sign_uint64) { // we disable this test for openvino as openvino ep supports only FP32 Precision TEST(MathOpTest, Sign_int64) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; @@ -146,7 +146,7 @@ TEST(MathOpTest, Sign_int64) { TEST(MathOpTest, Sign_float) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; @@ -162,7 +162,7 @@ TEST(MathOpTest, Sign_float) { TEST(MathOpTest, Sign_double) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; @@ -177,7 +177,7 @@ TEST(MathOpTest, Sign_double) { } TEST(MathOpTest, Sign_MLFloat16) { using namespace test_sign_internal; - OpTester test("Sign", 9); + OpTester test("Sign", 13); std::vector input_dims{7}; std::vector input; diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index 8126990df5..ee18cf2cea 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/framework/tensor.h" +#include "core/providers/cpu/nn/batch_norm.h" // for BATCHNORM_INCLUDE_TRAINING_SUPPORT #include "core/session/inference_session.h" #include "test/common/dnnl_op_test_utils.h" #include "test/providers/provider_test_utils.h" @@ -846,7 +847,7 @@ TEST(BatchNormTest, BatchNorm2d_bfloat16) { #endif // USE_DNNL // TODO fix flaky test for CUDA -#ifdef ENABLE_TRAINING_OPS +#ifdef BATCHNORM_INCLUDE_TRAINING_SUPPORT TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { @@ -936,7 +937,7 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) { {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } -#endif // ENABLE_TRAINING_OPS +#endif // BATCHNORM_INCLUDE_TRAINING_SUPPORT } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/python/onnxruntime_test_float8.py b/onnxruntime/test/python/onnxruntime_test_float8.py index 3f3180230f..76ca5d9538 100644 --- a/onnxruntime/test/python/onnxruntime_test_float8.py +++ b/onnxruntime/test/python/onnxruntime_test_float8.py @@ -8,9 +8,11 @@ import unittest import numpy as np +import packaging.version as pv import parameterized from numpy.testing import assert_allclose from onnx import TensorProto +from onnx import __version__ as onnx_version from onnx.checker import check_model from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor, make_tensor_value_info from onnx.reference import ReferenceEvaluator @@ -37,7 +39,7 @@ class TestInferenceSession(unittest.TestCase): `_. """ - dtypes = frozenset({"FLOAT": np.float32, "FLOAT16": np.float16}) + dtypes = {"FLOAT": np.float32, "FLOAT16": np.float16} # noqa: RUF012 x = np.array( [0.4068359375, 352, 416, 336, 304, 272, -248, -100, 1e-4, 1e-2, 416, 432, 1e5, np.inf, -np.inf, np.nan], dtype=np.float32, @@ -76,7 +78,7 @@ class TestInferenceSession(unittest.TestCase): 240.0, 240.0, -240.0, - -104.0, + -96.0, 0.0, 0.009765625, 240.0, @@ -113,7 +115,7 @@ class TestInferenceSession(unittest.TestCase): [ 0.4375, 384.0, - 448.0, + 384.0, 320.0, 320.0, 256.0, @@ -121,7 +123,7 @@ class TestInferenceSession(unittest.TestCase): -96.0, 0.0001068115234375, 0.009765625, - 448.0, + 384.0, 448.0, 57344.0, 57344.0, @@ -167,7 +169,7 @@ class TestInferenceSession(unittest.TestCase): np.nan, np.nan, np.nan, - -104.0, + -96.0, 0.0, 0.009765625, np.nan, @@ -204,7 +206,7 @@ class TestInferenceSession(unittest.TestCase): [ 0.4375, 384.0, - 448.0, + 384.0, 320.0, 320.0, 256.0, @@ -212,7 +214,7 @@ class TestInferenceSession(unittest.TestCase): -96.0, 0.0001068115234375, 0.009765625, - 448.0, + 384.0, 448.0, np.nan, np.nan, @@ -245,6 +247,7 @@ def model_cast_cast_f16_float(self, to, saturate, rev=False): check_model(onnx_model) return onnx_model + @unittest.skipIf(pv.Version(onnx_version) < pv.Version("1.15.0"), reason="needs onnx>=1.15.0") @parameterized.parameterized.expand( [ ("FLOAT8E4M3FN", "FLOAT", 1), @@ -429,6 +432,7 @@ def model_qdq(self, to, float_name, saturate, castq=False, castdq=False, like=Fa check_model(onnx_model) return onnx_model + @unittest.skipIf(pv.Version(onnx_version) < pv.Version("1.15.0"), reason="needs onnx>=1.15.0") @parameterized.parameterized.expand( [ ("FLOAT8E4M3FN", "FLOAT", 1), @@ -689,6 +693,18 @@ def test_model_qdq_cuda_ortvalue(self, name: str, float_name: str, saturate: int self.assertEqual(expect.shape, y.shape) self.assertEqual(expect.dtype, y.dtype) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.") + def test_compare_cpu_cuda_e4m3fn(self): + folder = os.path.join(os.path.dirname(__file__), "..", "testdata", "float8") + model = os.path.join(folder, "te.cast_fp8_1_fp32.onnx") + data = np.load(os.path.join(folder, "te.cast_fp8_1_fp32_input.npy")) + + sess_cpu = onnxruntime.InferenceSession(model, providers=["CPUExecutionProvider"]) + sess_cuda = onnxruntime.InferenceSession(model, providers=["CUDAExecutionProvider"]) + cpu_res = sess_cpu.run(None, {"input": data})[0] + cuda_res = sess_cuda.run(None, {"input": data})[0] + self.assertEqual(cuda_res.tolist(), cpu_res.tolist()) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index e554d41866..86577eaf2d 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -80,11 +80,7 @@ def test_model_serialization(self): so.log_severity_level = 1 so.logid = "TestModelSerialization" so.optimized_model_filepath = "./PythonApiTestOptimizedModel.onnx" - onnxrt.InferenceSession( - get_name("mul_1.onnx"), - sess_options=so, - providers=["CPUExecutionProvider"], - ) + onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) os.remove(so.optimized_model_filepath) except Fail as onnxruntime_error: @@ -179,6 +175,62 @@ def test_model_serialization_with_original_external_initializers_to_directory(se else: raise onnxruntime_error + def test_model_serialization_with_original_external_initializers_to_current_directory(self): + optimized_model_filepath = "model_opt_with_ext_data_1.onnx" + external_initializers_file = "model_opt_with_ext_data_1.bin" + optimized_model_filepath_2 = "model_opt_with_ext_data_2.onnx" + external_initializers_file_2 = "model_opt_with_ext_data_2.bin" + + so = onnxrt.SessionOptions() + so.log_severity_level = 1 + so.logid = "TestModelSerializationWithOriginalExternalInitializersToCurrentDirectory" + so.optimized_model_filepath = optimized_model_filepath + + so.add_session_config_entry( + "session.optimized_model_external_initializers_file_name", external_initializers_file + ) + + # TODO(anyone): Set this to 100 will cause test error since some tensor below the threshold + # still refers to the original external data file. We shall fix this issue so that the + # optimized model only refers to one external data file. + so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "10") + session1 = onnxrt.InferenceSession( + get_name("model_with_orig_ext_data.onnx"), sess_options=so, providers=["CPUExecutionProvider"] + ) + del session1 + self.assertTrue(os.path.isfile(optimized_model_filepath)) + self.assertTrue(os.path.isfile(external_initializers_file)) + + so2 = onnxrt.SessionOptions() + so2.log_severity_level = 1 + so2.logid = "TestModelSerializationWithExternalInitializersInCurrentDirectory" + so2.optimized_model_filepath = optimized_model_filepath_2 + so2.add_session_config_entry( + "session.optimized_model_external_initializers_file_name", external_initializers_file_2 + ) + so2.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "10") + + # verify that we can load the optimized model with external data in current directory and save + # optimized model with external data to current directory. + session2 = onnxrt.InferenceSession( + optimized_model_filepath, sess_options=so2, providers=["CPUExecutionProvider"] + ) + del session2 + self.assertTrue(os.path.isfile(optimized_model_filepath_2)) + self.assertTrue(os.path.isfile(external_initializers_file_2)) + + # Remove model 1 to make sure optimized model 2 can be loaded independently from model 1 + os.remove(optimized_model_filepath) + os.remove(external_initializers_file) + + session3 = onnxrt.InferenceSession( + optimized_model_filepath_2, sess_options=onnxrt.SessionOptions(), providers=["CPUExecutionProvider"] + ) + del session3 + + os.remove(optimized_model_filepath_2) + os.remove(external_initializers_file_2) + def test_get_providers(self): self.assertTrue("CPUExecutionProvider" in onnxrt.get_available_providers()) # get_all_providers() returns the default EP order from highest to lowest. diff --git a/onnxruntime/test/python/onnxruntime_test_python_iobinding.py b/onnxruntime/test/python/onnxruntime_test_python_iobinding.py index 8009d97ba3..56417f13fb 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_iobinding.py +++ b/onnxruntime/test/python/onnxruntime_test_python_iobinding.py @@ -16,40 +16,43 @@ from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue from onnxruntime.capi._pybind_state import OrtValueVector, SessionIOBinding +test_params = [ + ("cuda", "CUDAExecutionProvider", C_OrtDevice.cuda), + ("dml", "DmlExecutionProvider", C_OrtDevice.dml), +] + class TestIOBinding(unittest.TestCase): - def create_ortvalue_input_on_gpu(self): + def _create_ortvalue_input_on_gpu(self, device): return onnxrt.OrtValue.ortvalue_from_numpy( - np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), "cuda", 0 + np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), device, 0 ) - def create_ortvalue_alternate_input_on_gpu(self): + def _create_ortvalue_alternate_input_on_gpu(self, device): return onnxrt.OrtValue.ortvalue_from_numpy( np.array([[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]], dtype=np.float32), - "cuda", + device, 0, ) - def create_uninitialized_ortvalue_input_on_gpu(self): - return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, "cuda", 0) + def _create_uninitialized_ortvalue_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, device, 0) - def create_numpy_input(self): + def _create_numpy_input(self): return np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - def create_expected_output(self): + def _create_expected_output(self): return np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) - def create_expected_output_alternate(self): + def _create_expected_output_alternate(self): return np.array([[2.0, 8.0], [18.0, 32.0], [50.0, 72.0]], dtype=np.float32) def test_bind_input_to_cpu_arr(self): - self.create_numpy_input() - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) io_binding = session.io_binding() # Bind Numpy object (input) that's on CPU to wherever the model needs it - io_binding.bind_cpu_input("X", self.create_numpy_input()) + io_binding.bind_cpu_input("X", self._create_numpy_input()) # Bind output to CPU io_binding.bind_output("Y") @@ -57,254 +60,280 @@ def test_bind_input_to_cpu_arr(self): # Invoke Run session.run_with_iobinding(io_binding) - # Sync if different CUDA streams + # Sync if different streams io_binding.synchronize_outputs() - # Get outputs over to CPU (the outputs which were bound to CUDA will get copied over to the host here) + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host here) ort_output = io_binding.copy_outputs_to_cpu()[0] # Validate results - self.assertTrue(np.array_equal(self.create_expected_output(), ort_output)) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output)) - @unittest.skip("Could not find an implementation for Identity(19) node with name ''") def test_bind_input_types(self): - opset = onnx_opset_version() - devices = [ - ( - C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0), - ["CPUExecutionProvider"], - ) - ] - if "CUDAExecutionProvider" in onnxrt.get_all_providers(): - devices.append( - ( - C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0), - ["CUDAExecutionProvider"], - ) - ) - - for device, provider in devices: - for dtype in [ - np.float32, - np.float64, - np.int32, - np.uint32, - np.int64, - np.uint64, - np.int16, - np.uint16, - np.int8, - np.uint8, - np.float16, - np.bool_, - ]: - with self.subTest(dtype=dtype, device=str(device)): - x = np.arange(8).reshape((-1, 2)).astype(dtype) - proto_dtype = NP_TYPE_TO_TENSOR_TYPE[x.dtype] - - X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) # noqa: N806 - Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # noqa: N806 - - # inference - node_add = helper.make_node("Identity", ["X"], ["Y"]) - - # graph - graph_def = helper.make_graph([node_add], "lr", [X], [Y], []) - model_def = helper.make_model( - graph_def, - producer_name="dummy", - ir_version=7, - producer_version="0", - opset_imports=[helper.make_operatorsetid("", opset)], - ) - - sess = onnxrt.InferenceSession(model_def.SerializeToString(), providers=provider) - - bind = SessionIOBinding(sess._sess) - ort_value = C_OrtValue.ortvalue_from_numpy(x, device) - bind.bind_ortvalue_input("X", ort_value) - bind.bind_output("Y", device) - sess._sess.run_with_iobinding(bind, None) - ortvaluevector = bind.get_outputs() - self.assertIsInstance(ortvaluevector, OrtValueVector) - ortvalue = bind.get_outputs()[0] - y = ortvalue.numpy() - assert_almost_equal(x, y) - - bind = SessionIOBinding(sess._sess) - bind.bind_input("X", device, dtype, x.shape, ort_value.data_ptr()) - bind.bind_output("Y", device) - sess._sess.run_with_iobinding(bind, None) - ortvalue = bind.get_outputs()[0] - y = ortvalue.numpy() - assert_almost_equal(x, y) + for device, execution_provider, generate_device in test_params: + with self.subTest(execution_provider): + if execution_provider not in onnxrt.get_available_providers(): + self.skipTest(f"Skipping on {device.upper()}.") + + opset = onnx_opset_version() + devices = [ + ( + C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0), + ["CPUExecutionProvider"], + ), + ( + C_OrtDevice(generate_device(), C_OrtDevice.default_memory(), 0), + [execution_provider], + ), + ] + + for inner_device, provider in devices: + for dtype in [ + np.float32, + np.float64, + np.int32, + np.uint32, + np.int64, + np.uint64, + np.int16, + np.uint16, + np.int8, + np.uint8, + np.float16, + np.bool_, + ]: + with self.subTest(dtype=dtype, inner_device=str(inner_device)): + x = np.arange(8).reshape((-1, 2)).astype(dtype) + proto_dtype = NP_TYPE_TO_TENSOR_TYPE[x.dtype] + + X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) # noqa: N806 + Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # noqa: N806 + + # inference + node_add = helper.make_node("Identity", ["X"], ["Y"]) + + # graph + graph_def = helper.make_graph([node_add], "lr", [X], [Y], []) + model_def = helper.make_model( + graph_def, + producer_name="dummy", + ir_version=7, + producer_version="0", + opset_imports=[helper.make_operatorsetid("", opset)], + ) + + sess = onnxrt.InferenceSession(model_def.SerializeToString(), providers=provider) + + bind = SessionIOBinding(sess._sess) + ort_value = C_OrtValue.ortvalue_from_numpy(x, inner_device) + bind.bind_ortvalue_input("X", ort_value) + bind.bind_output("Y", inner_device) + sess._sess.run_with_iobinding(bind, None) + ortvaluevector = bind.get_outputs() + self.assertIsInstance(ortvaluevector, OrtValueVector) + ortvalue = bind.get_outputs()[0] + y = ortvalue.numpy() + assert_almost_equal(x, y) + + bind = SessionIOBinding(sess._sess) + bind.bind_input("X", inner_device, dtype, x.shape, ort_value.data_ptr()) + bind.bind_output("Y", inner_device) + sess._sess.run_with_iobinding(bind, None) + ortvalue = bind.get_outputs()[0] + y = ortvalue.numpy() + assert_almost_equal(x, y) def test_bind_input_only(self): - input = self.create_ortvalue_input_on_gpu() + for device, execution_provider, _ in test_params: + with self.subTest(execution_provider): + if execution_provider not in onnxrt.get_available_providers(): + self.skipTest(f"Skipping on {device.upper()}.") + input = self._create_ortvalue_input_on_gpu(device) - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) - io_binding = session.io_binding() + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) + io_binding = session.io_binding() - # Bind input to CUDA - io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + # Bind input to the GPU + io_binding.bind_input("X", device, 0, np.float32, [3, 2], input.data_ptr()) - # Sync if different CUDA streams - io_binding.synchronize_inputs() + # Sync if different streams + io_binding.synchronize_inputs() - # Bind output to CPU - io_binding.bind_output("Y") + # Bind output to CPU + io_binding.bind_output("Y") - # Invoke Run - session.run_with_iobinding(io_binding) + # Invoke Run + session.run_with_iobinding(io_binding) - # Sync if different CUDA streams - io_binding.synchronize_outputs() + # Sync if different streams + io_binding.synchronize_outputs() - # Get outputs over to CPU (the outputs which were bound to CUDA will get copied over to the host here) - ort_output = io_binding.copy_outputs_to_cpu()[0] + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host + # here) + ort_output = io_binding.copy_outputs_to_cpu()[0] - # Validate results - self.assertTrue(np.array_equal(self.create_expected_output(), ort_output)) + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output)) def test_bind_input_and_preallocated_output(self): - input = self.create_ortvalue_input_on_gpu() + for device, execution_provider, _ in test_params: + with self.subTest(execution_provider): + if execution_provider not in onnxrt.get_available_providers(): + self.skipTest(f"Skipping on {device.upper()}.") - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) - io_binding = session.io_binding() - - # Bind input to CUDA - io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) - - # Bind output to CUDA - output = self.create_uninitialized_ortvalue_input_on_gpu() - io_binding.bind_output("Y", "cuda", 0, np.float32, [3, 2], output.data_ptr()) - - # Sync if different CUDA streams - io_binding.synchronize_inputs() - - # Invoke Run - session.run_with_iobinding(io_binding) + input = self._create_ortvalue_input_on_gpu(device) - # Sync if different CUDA streams - io_binding.synchronize_outputs() + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) + io_binding = session.io_binding() - # Get outputs over to CPU (the outputs which were bound to CUDA will get copied over to the host here) - ort_output_vals = io_binding.copy_outputs_to_cpu()[0] - # Validate results - self.assertTrue(np.array_equal(self.create_expected_output(), ort_output_vals)) + # Bind input to the GPU + io_binding.bind_input("X", device, 0, np.float32, [3, 2], input.data_ptr()) - # Validate if ORT actually wrote to pre-allocated buffer by copying the Torch allocated buffer - # to the host and validating its contents - ort_output_vals_in_cpu = output.numpy() - # Validate results - self.assertTrue(np.array_equal(self.create_expected_output(), ort_output_vals_in_cpu)) + # Bind output to the GPU + output = self._create_uninitialized_ortvalue_input_on_gpu(device) + io_binding.bind_output("Y", device, 0, np.float32, [3, 2], output.data_ptr()) - def test_bind_input_and_non_preallocated_output(self): - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) - io_binding = session.io_binding() + # Sync if different streams + io_binding.synchronize_inputs() - # Bind input to CUDA - io_binding.bind_input( - "X", - "cuda", - 0, - np.float32, - [3, 2], - self.create_ortvalue_input_on_gpu().data_ptr(), - ) + # Invoke Run + session.run_with_iobinding(io_binding) - # Bind output to CUDA - io_binding.bind_output("Y", "cuda") + # Sync if different streams + io_binding.synchronize_outputs() - # Sync if different CUDA streams - io_binding.synchronize_inputs() + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host + # here) + ort_output_vals = io_binding.copy_outputs_to_cpu()[0] + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals)) - # Invoke Run - session.run_with_iobinding(io_binding) + # Validate if ORT actually wrote to pre-allocated buffer by copying the allocated buffer + # to the host and validating its contents + ort_output_vals_in_cpu = output.numpy() + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals_in_cpu)) - # Sync if different CUDA streams - io_binding.synchronize_outputs() + def test_bind_input_and_non_preallocated_output(self): + for device, execution_provider, _ in test_params: + with self.subTest(execution_provider): + if execution_provider not in onnxrt.get_available_providers(): + self.skipTest(f"Skipping on {device.upper()}.") + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) + io_binding = session.io_binding() + + input = self._create_ortvalue_input_on_gpu(device) + + # Bind input to the GPU + io_binding.bind_input( + "X", + device, + 0, + np.float32, + [3, 2], + input.data_ptr(), + ) - # This call returns an OrtValue which has data allocated by ORT on CUDA - ort_outputs = io_binding.get_outputs() - self.assertEqual(len(ort_outputs), 1) - self.assertEqual(ort_outputs[0].device_name(), "cuda") - # Validate results (by copying results to CPU by creating a Numpy object) - self.assertTrue(np.array_equal(self.create_expected_output(), ort_outputs[0].numpy())) - - # We should be able to repeat the above process as many times as we want - try once more - ort_outputs = io_binding.get_outputs() - self.assertEqual(len(ort_outputs), 1) - self.assertEqual(ort_outputs[0].device_name(), "cuda") - # Validate results (by copying results to CPU by creating a Numpy object) - self.assertTrue(np.array_equal(self.create_expected_output(), ort_outputs[0].numpy())) - - # Change the bound input and validate the results in the same bound OrtValue - # Bind alternate input to CUDA - io_binding.bind_input( - "X", - "cuda", - 0, - np.float32, - [3, 2], - self.create_ortvalue_alternate_input_on_gpu().data_ptr(), - ) + # Bind output to the GPU + io_binding.bind_output("Y", device) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # This call returns an OrtValue which has data allocated by ORT on the GPU + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), device) + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) + + # We should be able to repeat the above process as many times as we want - try once more + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), device) + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) + + input = self._create_ortvalue_alternate_input_on_gpu(device) + + # Change the bound input and validate the results in the same bound OrtValue + # Bind alternate input to the GPU + io_binding.bind_input( + "X", + device, + 0, + np.float32, + [3, 2], + input.data_ptr(), + ) - # Sync if different CUDA streams - io_binding.synchronize_inputs() + # Sync if different streams + io_binding.synchronize_inputs() - # Invoke Run - session.run_with_iobinding(io_binding) + # Invoke Run + session.run_with_iobinding(io_binding) - # Sync if different CUDA streams - io_binding.synchronize_outputs() + # Sync if different streams + io_binding.synchronize_outputs() - # This call returns an OrtValue which has data allocated by ORT on CUDA - ort_outputs = io_binding.get_outputs() - self.assertEqual(len(ort_outputs), 1) - self.assertEqual(ort_outputs[0].device_name(), "cuda") - # Validate results (by copying results to CPU by creating a Numpy object) - self.assertTrue(np.array_equal(self.create_expected_output_alternate(), ort_outputs[0].numpy())) + # This call returns an OrtValue which has data allocated by ORT on the GPU + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), device) + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output_alternate(), ort_outputs[0].numpy())) def test_bind_input_and_bind_output_with_ortvalues(self): - session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) - io_binding = session.io_binding() + for device, execution_provider, _ in test_params: + with self.subTest(execution_provider): + if execution_provider not in onnxrt.get_available_providers(): + self.skipTest(f"Skipping on {device.upper()}.") - # Bind ortvalue as input - input_ortvalue = self.create_ortvalue_input_on_gpu() - io_binding.bind_ortvalue_input("X", input_ortvalue) + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) + io_binding = session.io_binding() - # Bind ortvalue as output - output_ortvalue = self.create_uninitialized_ortvalue_input_on_gpu() - io_binding.bind_ortvalue_output("Y", output_ortvalue) + # Bind ortvalue as input + input_ortvalue = self._create_ortvalue_input_on_gpu(device) + io_binding.bind_ortvalue_input("X", input_ortvalue) - # Sync if different CUDA streams - io_binding.synchronize_inputs() + # Bind ortvalue as output + output_ortvalue = self._create_uninitialized_ortvalue_input_on_gpu(device) + io_binding.bind_ortvalue_output("Y", output_ortvalue) - # Invoke Run - session.run_with_iobinding(io_binding) + # Sync if different streams + io_binding.synchronize_inputs() - # Sync if different CUDA streams - io_binding.synchronize_outputs() + # Invoke Run + session.run_with_iobinding(io_binding) - # Inspect contents of output_ortvalue and make sure that it has the right contents - self.assertTrue(np.array_equal(self.create_expected_output(), output_ortvalue.numpy())) + # Sync if different streams + io_binding.synchronize_outputs() - # Bind another ortvalue as input - input_ortvalue_2 = self.create_ortvalue_alternate_input_on_gpu() - io_binding.bind_ortvalue_input("X", input_ortvalue_2) + # Inspect contents of output_ortvalue and make sure that it has the right contents + self.assertTrue(np.array_equal(self._create_expected_output(), output_ortvalue.numpy())) - # Sync if different CUDA streams - io_binding.synchronize_inputs() + # Bind another ortvalue as input + input_ortvalue_2 = self._create_ortvalue_alternate_input_on_gpu(device) + io_binding.bind_ortvalue_input("X", input_ortvalue_2) - # Invoke Run - session.run_with_iobinding(io_binding) + # Sync if different streams + io_binding.synchronize_inputs() - # Sync if different CUDA streams - io_binding.synchronize_outputs() + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() - # Inspect contents of output_ortvalue and make sure that it has the right contents - self.assertTrue(np.array_equal(self.create_expected_output_alternate(), output_ortvalue.numpy())) + # Inspect contents of output_ortvalue and make sure that it has the right contents + self.assertTrue(np.array_equal(self._create_expected_output_alternate(), output_ortvalue.numpy())) if __name__ == "__main__": diff --git a/onnxruntime/test/python/quantization/resnet_code.py b/onnxruntime/test/python/quantization/resnet_code.py new file mode 100644 index 0000000000..74e3652673 --- /dev/null +++ b/onnxruntime/test/python/quantization/resnet_code.py @@ -0,0 +1,13763 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import numpy +from onnx import numpy_helper +from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info, set_model_props + + +def create_model(): + initializers = [] + nodes = [] + inputs = [] + outputs = [] + functions = [] + + # opsets + opsets = {"": 13} + + # initializers + + list_value = [ + -0.013732648454606533, + -0.005861935671418905, + 0.06889285147190094, + -0.1172710582613945, + 0.08841240406036377, + -0.03748627379536629, + 0.016256270930171013, + -0.1059316024184227, + 0.08246039599180222, + 0.14295539259910583, + -0.32958757877349854, + 0.1631188541650772, + 0.05412565544247627, + -0.10758306831121445, + 0.12607362866401672, + -0.4987836182117462, + 0.7441706657409668, + -0.24774713814258575, + -0.30415549874305725, + 0.4033295810222626, + -0.13447114825248718, + 0.04623159021139145, + 0.2380414456129074, + -1.226112723350525, + 2.150630235671997, + -1.702580213546753, + 0.5305419564247131, + -0.06836353242397308, + -0.20055373013019562, + 0.7035881280899048, + -0.8389442563056946, + -0.1904432326555252, + 1.2609282732009888, + -1.0670661926269531, + 0.4142579436302185, + 0.04739700257778168, + -0.3265092074871063, + 1.1873037815093994, + -1.6817731857299805, + 0.9709527492523193, + -0.09095840901136398, + -0.12556785345077515, + 0.0835147574543953, + -0.24109329283237457, + 0.032948240637779236, + 0.46304041147232056, + -0.6594106554985046, + 0.349990576505661, + -0.04113377630710602, + 0.016451245173811913, + 0.008994563482701778, + -0.028321878984570503, + -0.05336569994688034, + 0.16036668419837952, + -0.12088149785995483, + 0.031160499900579453, + -0.0618649423122406, + 0.07205374538898468, + 0.15965768694877625, + -0.3389044404029846, + 0.21603335440158844, + 0.04029613360762596, + -0.0813034325838089, + 0.1019665077328682, + -0.4873599112033844, + 0.7873126268386841, + -0.2951086163520813, + -0.43754327297210693, + 0.5905176401138306, + -0.21821773052215576, + 0.06022067740559578, + 0.26326146721839905, + -1.6453089714050293, + 2.606400728225708, + -1.8939754962921143, + 0.5196341276168823, + 0.0055860355496406555, + -0.2335057258605957, + 0.9807199239730835, + -1.2137882709503174, + -0.2699125409126282, + 1.7379733324050903, + -1.4401814937591553, + 0.435971736907959, + -0.04829222336411476, + -0.24543480575084686, + 1.3292583227157593, + -2.0375823974609375, + 1.2458536624908447, + -0.08251484483480453, + -0.14181238412857056, + 0.10612589120864868, + -0.21671657264232635, + 0.1129523366689682, + 0.3666985034942627, + -0.7546612024307251, + 0.42979565262794495, + -0.0976259633898735, + -0.0008812264422886074, + 0.02994859404861927, + -0.07027778774499893, + 0.01393035613000393, + 0.07363647222518921, + -0.10249849408864975, + 0.06602989137172699, + -0.012129798531532288, + 0.10730132460594177, + -0.04546127840876579, + -0.16065146028995514, + 0.14788293838500977, + -0.05488971993327141, + 0.03601694852113724, + 0.07513345777988434, + -0.23953600227832794, + 0.48062530159950256, + -0.42057543992996216, + -0.02402813360095024, + 0.17920851707458496, + -0.10703158378601074, + -0.028666120022535324, + 0.2815375030040741, + -0.860264241695404, + 1.4422725439071655, + -1.2058128118515015, + 0.5272247791290283, + -0.06504356116056442, + -0.20021803677082062, + 0.44968947768211365, + -0.3856053650379181, + -0.1589551419019699, + 0.7579770684242249, + -0.8349987268447876, + 0.3225692808628082, + 0.08153475821018219, + -0.43163740634918213, + 0.8742384910583496, + -0.9722443222999573, + 0.579015851020813, + -0.06688100844621658, + -0.12384293973445892, + 0.08289378881454468, + -0.10082041472196579, + -0.11204896867275238, + 0.3934254050254822, + -0.4511864185333252, + 0.32745760679244995, + -0.06534548103809357, + -0.028830429539084435, + 0.021844232454895973, + 0.01775779016315937, + -0.004250001162290573, + 0.013087524101138115, + -0.001250433037057519, + -0.040545206516981125, + -0.014049320481717587, + -0.024194253608584404, + -0.023865194991230965, + -0.0038033330347388983, + 0.00920871365815401, + -0.006582418456673622, + 0.0032474950421601534, + -0.0369916632771492, + -0.16640843451023102, + -0.28968843817710876, + -0.3531132638454437, + -0.26307201385498047, + -0.13392697274684906, + -0.03747623786330223, + 0.08083077520132065, + 0.2026241272687912, + 0.25018608570098877, + 0.2529378831386566, + 0.2307336926460266, + 0.13928599655628204, + 0.08631229400634766, + 0.13893137872219086, + 0.4867081344127655, + 0.7170669436454773, + 0.8331555724143982, + 0.6734364032745361, + 0.3549460768699646, + 0.16798041760921478, + -0.14487245678901672, + -0.47733625769615173, + -0.7670150995254517, + -0.875726580619812, + -0.6291986703872681, + -0.2910463213920593, + -0.09991979598999023, + -0.009158087894320488, + 0.018850643187761307, + 0.02646111696958542, + -0.009077857248485088, + 0.029430989176034927, + -0.03707962855696678, + -0.05111744999885559, + -0.02076525054872036, + 0.011828843504190445, + 0.017857171595096588, + 0.02548048458993435, + -0.009077494964003563, + 0.0022066361270844936, + -0.02064262516796589, + -0.008582246489822865, + -0.022748643532395363, + -0.03038850985467434, + 0.0006585497176274657, + -0.0016039719339460135, + -0.01612498238682747, + 0.013966801576316357, + -0.05851661041378975, + -0.21422894299030304, + -0.33863192796707153, + -0.3720807433128357, + -0.3030800521373749, + -0.1737397164106369, + -0.05903157964348793, + 0.15018144249916077, + 0.27454254031181335, + 0.31182464957237244, + 0.30118387937545776, + 0.24605700373649597, + 0.14123573899269104, + 0.14992672204971313, + 0.20660799741744995, + 0.5046274662017822, + 0.7706091403961182, + 0.8978630900382996, + 0.7368614673614502, + 0.3929724097251892, + 0.23079657554626465, + -0.21169082820415497, + -0.5920398235321045, + -0.893406867980957, + -0.9499238729476929, + -0.730407178401947, + -0.3615736961364746, + -0.15422092378139496, + -0.024615347385406494, + 0.005115498788654804, + 0.024657316505908966, + 0.028517475351691246, + 0.027910854667425156, + -0.009482389315962791, + -0.042242538183927536, + -0.017875321209430695, + 0.00430292496457696, + 0.015949612483382225, + 0.003636278910562396, + -0.018156034871935844, + -0.0009349065367132425, + -0.0010362856555730104, + -0.013051170855760574, + -0.009141271002590656, + -8.714485738892108e-05, + 0.02399279735982418, + 0.01753612607717514, + -0.013710699044167995, + -0.014245252124965191, + -0.0028008236549794674, + -0.08206935226917267, + -0.1098734438419342, + -0.10250325500965118, + -0.08874496072530746, + -0.031079040840268135, + 0.004536658991128206, + 0.03923843801021576, + 0.08478657901287079, + 0.07715648412704468, + 0.018803801387548447, + 0.013921198435127735, + 0.015864359214901924, + 0.04947463795542717, + 0.039856068789958954, + 0.1712094396352768, + 0.362756609916687, + 0.4192918539047241, + 0.2668488621711731, + 0.11430513113737106, + 0.06648365408182144, + -0.058979276567697525, + -0.24177154898643494, + -0.3709423542022705, + -0.3979431986808777, + -0.29706764221191406, + -0.11569518595933914, + -0.01848490908741951, + -0.015523962676525116, + 0.05081642046570778, + 0.09057094901800156, + 0.08520761132240295, + 0.04497350752353668, + -0.019453801214694977, + -0.06109466031193733, + 0.011463015340268612, + -0.008522219955921173, + -0.005283404141664505, + -0.017313135787844658, + -0.0015744483098387718, + -0.011845857836306095, + -0.016727561131119728, + -0.006708915811032057, + 0.0008860539528541267, + -0.010050912387669086, + -0.028460539877414703, + -0.0165643822401762, + -0.016545938327908516, + -0.00567589420825243, + -0.0032017906196415424, + -0.0130555285140872, + -0.026848897337913513, + -0.02615198865532875, + 0.002669057110324502, + -0.027966763824224472, + -0.03851256147027016, + -0.014509409666061401, + -0.029059220105409622, + -0.007284109480679035, + 0.04045313969254494, + 0.10005538910627365, + 0.014574537053704262, + -0.044292762875556946, + -0.01750861294567585, + -0.02231375314295292, + -0.004432118032127619, + 0.10051869601011276, + 0.1443023532629013, + 0.0508832149207592, + -0.04350621998310089, + -0.0025447055231779814, + -0.014583000913262367, + -0.02153291553258896, + 0.018860718235373497, + 0.03618147224187851, + 0.007304056081920862, + -0.029104959219694138, + 0.00576505484059453, + -0.016025763005018234, + -0.025094063952565193, + -0.05296780914068222, + -0.037012189626693726, + -0.04414081946015358, + -0.053135257214307785, + -0.028890708461403847, + -0.010220452211797237, + -0.027575822547078133, + -0.01087758969515562, + -0.027209162712097168, + -0.030827227979898453, + -0.007646164856851101, + -0.016133273020386696, + 0.000639698002487421, + -0.0034172122832387686, + 0.03914793208241463, + 0.030786357820034027, + 0.005965455900877714, + 0.020923329517245293, + -0.03435938432812691, + -0.0026781477499753237, + 0.04278327897191048, + 0.20045910775661469, + 0.21770593523979187, + 0.09422573447227478, + 0.03198440372943878, + -0.021056609228253365, + 0.028007682412862778, + 0.19196027517318726, + 0.4791645109653473, + 0.5333831906318665, + 0.3014310598373413, + 0.103666290640831, + -0.03651479259133339, + 0.027079502120614052, + 0.19239209592342377, + 0.5168290138244629, + 0.5564895868301392, + 0.2977963089942932, + 0.07770062237977982, + -0.042239490896463394, + -0.017265107482671738, + 0.08760321140289307, + 0.2775075435638428, + 0.312491774559021, + 0.12284757196903229, + 0.019664151594042778, + -0.026643047109246254, + 0.0009152573184110224, + 0.016156431287527084, + 0.09042830765247345, + 0.08991760015487671, + 0.013326293788850307, + 0.02613811008632183, + 0.021025240421295166, + 0.0198842640966177, + 0.03375901281833649, + 0.028616728261113167, + 0.026605166494846344, + 0.04126269370317459, + 0.029309948906302452, + 0.01408455427736044, + -0.003831037785857916, + 0.01922326348721981, + -0.018229445442557335, + -0.013015883974730968, + 0.017597628757357597, + -0.007964612916111946, + 0.045263469219207764, + 0.0184696726500988, + -0.001163159729912877, + -0.1809321641921997, + -0.22486254572868347, + -0.08606110513210297, + 0.001087217591702938, + 0.037091098725795746, + -0.013625397346913815, + -0.178089901804924, + -0.5483279824256897, + -0.612791895866394, + -0.32531827688217163, + -0.06506585329771042, + 0.05076128616929054, + -0.007585812360048294, + -0.20981833338737488, + -0.6155760884284973, + -0.7119701504707336, + -0.354442298412323, + -0.04236743599176407, + 0.045713260769844055, + 0.03192479908466339, + -0.07216271013021469, + -0.310979425907135, + -0.3656359910964966, + -0.13522450625896454, + 0.008291869424283504, + 0.03362602740526199, + -0.0009240762447007, + 0.01604474149644375, + -0.055634208023548126, + -0.06180194392800331, + 0.0222025066614151, + 0.027704820036888123, + -0.034385330975055695, + -0.07050742954015732, + -0.06287489086389542, + 0.03521641716361046, + -0.00020920530369039625, + 0.05458284169435501, + 0.058752644807100296, + -0.08097169548273087, + -0.01668735221028328, + 0.18557283282279968, + 0.26208117604255676, + 0.1253771185874939, + 0.07758381962776184, + -0.022084739059209824, + 0.016727397218346596, + 0.23247942328453064, + 0.35444316267967224, + 0.21802566945552826, + -0.04409221559762955, + -0.08573070168495178, + -0.0994141548871994, + 0.07754423469305038, + 0.14311672747135162, + 0.04036660119891167, + -0.29222917556762695, + -0.38828015327453613, + -0.26185816526412964, + -0.12845511734485626, + 0.04763585329055786, + -0.017382778227329254, + -0.16010743379592896, + -0.2395028918981552, + -0.2049665004014969, + -0.041346337646245956, + 0.091490738093853, + -0.005191737785935402, + -0.07687077671289444, + -0.08105621486902237, + -0.05329642817378044, + -0.03404862806200981, + 0.11478845030069351, + 0.13328343629837036, + -0.037197597324848175, + -0.01787363924086094, + -0.016605347394943237, + 0.007853846065700054, + 0.029950136318802834, + 0.10808859020471573, + 0.02873288467526436, + -0.1766187697649002, + -0.17560969293117523, + -0.03922238200902939, + 0.14447443187236786, + 0.1534212827682495, + 0.11272227019071579, + 0.008810695260763168, + -0.1485181748867035, + 0.07839693129062653, + 0.43013128638267517, + 0.4898712635040283, + 0.26522761583328247, + 0.10202436149120331, + -0.07163076847791672, + 0.09933187812566757, + 0.47377726435661316, + 0.6340300440788269, + 0.36741772294044495, + -0.04812543839216232, + -0.17370514571666718, + -0.17513291537761688, + 0.22105705738067627, + 0.3226463794708252, + 0.09850790351629257, + -0.4044247269630432, + -0.6237908601760864, + -0.4679968059062958, + -0.1954391747713089, + 0.09878316521644592, + -0.004430827684700489, + -0.31550562381744385, + -0.5235733985900879, + -0.4510284662246704, + -0.13843706250190735, + 0.10064390301704407, + -0.006748788990080357, + -0.12714813649654388, + -0.2107744812965393, + -0.18755048513412476, + -0.05646044388413429, + 0.12781813740730286, + 0.18928050994873047, + -0.04337320104241371, + -0.04973407834768295, + -0.04690375551581383, + 0.0245530866086483, + 0.10698680579662323, + 0.1646823137998581, + 0.081840381026268, + -0.01471243891865015, + -0.03138890117406845, + -0.04195617139339447, + 0.012708203867077827, + 0.033312954008579254, + 0.02409377694129944, + -0.0036440726835280657, + -0.06239784508943558, + 0.0037516560405492783, + 0.11261500418186188, + 0.13069754838943481, + 0.05901307612657547, + 0.048614490777254105, + -0.027712708339095116, + 0.027247682213783264, + 0.19195327162742615, + 0.2688453793525696, + 0.1509387195110321, + 0.020540937781333923, + -0.004100556951016188, + -0.012650247663259506, + 0.039176344871520996, + 0.09037251025438309, + -0.004689970053732395, + -0.23859903216362, + -0.2364242821931839, + -0.15189304947853088, + -0.0761493444442749, + -0.0028172829188406467, + -0.04328106716275215, + -0.16187387704849243, + -0.21743592619895935, + -0.1282283067703247, + -0.024501819163560867, + 0.04029383510351181, + -0.027387680485844612, + -0.05414740741252899, + -0.08344019204378128, + -0.06591048091650009, + 0.012637111358344555, + 0.06905930489301682, + 0.08426016569137573, + -0.0030199100729078054, + 0.034059297293424606, + 0.01111840270459652, + 0.013492933474481106, + 0.0674189031124115, + 0.08242739737033844, + 0.006129032466560602, + -0.07763395458459854, + -0.03002289868891239, + -0.055725954473018646, + 0.008795201778411865, + 0.02994825504720211, + -0.06114519387483597, + -0.0560108907520771, + -0.008179228752851486, + -0.07149285078048706, + -0.02700420655310154, + -0.01306728646159172, + 0.06276566535234451, + 0.007125973701477051, + -0.03540417551994324, + -0.039717916399240494, + 0.009147526696324348, + -0.06517947465181351, + 0.0720859095454216, + -0.05035398155450821, + 0.06659520417451859, + -0.01841895841062069, + 0.004233633633702993, + -0.020911216735839844, + -0.004646372981369495, + 1.6690073013305664, + 0.4517613649368286, + -0.07667035609483719, + 0.005556757096201181, + -0.02638973295688629, + 0.044588603079319, + -0.020916732028126717, + 0.2571280598640442, + -0.009559552185237408, + -0.043380800634622574, + 0.03196016326546669, + -0.03783237189054489, + -0.03076902963221073, + 0.03180111199617386, + 0.06352709978818893, + 0.020281998440623283, + -0.00741154421120882, + -0.0009214285528287292, + -0.0476187989115715, + -0.07208544760942459, + -0.05323023349046707, + -0.011103631928563118, + 0.02877136506140232, + -0.05324484035372734, + -0.10076326876878738, + 0.026193000376224518, + 0.03536469116806984, + 0.045722659677267075, + -0.03756006807088852, + 0.022998394444584846, + 0.0019359687576070428, + 0.01654801517724991, + 0.047304198145866394, + -0.08431598544120789, + -0.0645647644996643, + -0.17326746881008148, + -0.10692577064037323, + -0.08416426181793213, + -0.04107839986681938, + -0.0012680464424192905, + -0.02600814774632454, + -0.014215772971510887, + 0.2114446610212326, + -0.040954578667879105, + -0.05050172284245491, + 0.004194092936813831, + -0.0025900816544890404, + -0.1359374076128006, + 0.03946976363658905, + 2.3023669719696045, + 0.7484877109527588, + -0.1994970589876175, + -0.06490366160869598, + 0.007983183488249779, + -0.017937449738383293, + -0.12516839802265167, + 0.3313288688659668, + 0.11946671456098557, + -0.16942338645458221, + -0.007721045054495335, + 0.02824605070054531, + -0.05310647189617157, + -0.1122083067893982, + -0.17094524204730988, + -0.08465421944856644, + -0.09679102897644043, + -0.03848385065793991, + 0.040121182799339294, + -0.06661732494831085, + 0.0005764663219451904, + -0.05729356408119202, + -0.04778655245900154, + -0.034835152328014374, + -0.07634143531322479, + -0.05054831504821777, + 0.00597620103508234, + 0.04499154910445213, + -0.03308190405368805, + -0.04915233701467514, + -0.05842791870236397, + 0.003590918146073818, + 0.055837079882621765, + -0.02547842636704445, + -0.018847621977329254, + -0.2073899656534195, + -0.14987564086914062, + -0.03971748799085617, + 0.05886378139257431, + 0.020922083407640457, + -0.039155181497335434, + -0.028855402022600174, + 0.08688661456108093, + -0.1402827501296997, + -0.05810496211051941, + 0.037841811776161194, + -0.04082907736301422, + -0.1191127747297287, + -0.10852136462926865, + 1.6274418830871582, + 0.3678200840950012, + -0.2865799367427826, + -0.05291350558400154, + 0.023858532309532166, + -0.046683818101882935, + -0.2307816743850708, + -0.001670230645686388, + -0.17716962099075317, + -0.16724731028079987, + 0.040194038301706314, + -0.023075448349118233, + -0.01538322027772665, + -0.07914327085018158, + -0.19621343910694122, + -0.11628971993923187, + -0.05851752683520317, + 0.06313594430685043, + 0.017808571457862854, + 0.02447943389415741, + 0.048611078411340714, + -0.009247995913028717, + 0.00789090245962143, + 0.06673033535480499, + 0.0661577433347702, + 0.019111329689621925, + 0.038164373487234116, + 0.029342610388994217, + -0.03547409921884537, + -0.11017149686813354, + -0.11077891290187836, + 0.001108204829506576, + -0.0330691784620285, + -0.05039837956428528, + 0.017638904973864555, + 0.277705579996109, + 0.5606598258018494, + 0.5469182133674622, + 0.13591277599334717, + 0.012421006336808205, + 0.046348799020051956, + -0.02721901424229145, + -0.5645118355751038, + -1.072814702987671, + -0.9852984547615051, + -0.3608386516571045, + -0.010197073221206665, + -0.09785731136798859, + -0.02597353421151638, + 0.4627133309841156, + 1.1483618021011353, + 0.9505703449249268, + 0.17471027374267578, + -0.016467586159706116, + 0.026623696088790894, + 0.04765752702951431, + -0.4000166058540344, + -0.8956774473190308, + -0.6268588304519653, + -0.09439487755298615, + 0.02861764468252659, + -0.004155704285949469, + 0.08989865332841873, + 0.27384331822395325, + 0.6518518328666687, + 0.4184596836566925, + 0.13106893002986908, + 0.0050344159826636314, + 0.007061495911329985, + -0.016157688573002815, + -0.1364346295595169, + -0.27324289083480835, + -0.14245718717575073, + -0.04623992741107941, + -0.015541884116828442, + 0.030779436230659485, + 0.03756715729832649, + 0.01957445964217186, + -0.04964561015367508, + -0.0211405660957098, + 0.044496409595012665, + -0.026335055008530617, + -0.11620140820741653, + -0.11803250014781952, + 0.18242181837558746, + 0.5057784914970398, + 0.5045838952064514, + 0.03748183697462082, + 0.05692485347390175, + 0.1608155369758606, + 0.02245517633855343, + -0.7651812434196472, + -1.5504053831100464, + -1.3563542366027832, + -0.4314505457878113, + -0.028384560719132423, + -0.12238024920225143, + 0.106974296271801, + 1.11427903175354, + 2.173083543777466, + 1.747692346572876, + 0.5455064177513123, + 0.03363418206572533, + 0.11388687789440155, + -0.05905687436461449, + -0.8059568405151367, + -1.6196117401123047, + -1.1898213624954224, + -0.2654758095741272, + -0.004251840524375439, + -0.0916782096028328, + -0.024067873135209084, + 0.22692462801933289, + 0.6695711612701416, + 0.3673460781574249, + -0.017016466706991196, + -0.029604146257042885, + 0.020365707576274872, + 0.03215239942073822, + 0.0070981839671730995, + -0.14026938378810883, + -0.02425236999988556, + 0.059152450412511826, + -0.006319367326796055, + 0.003989882301539183, + 0.048541076481342316, + 0.003988460637629032, + -0.03105335496366024, + -0.08329232037067413, + 0.03226872906088829, + 0.02119620516896248, + -0.0953872874379158, + -0.15174035727977753, + 0.07963212579488754, + 0.29094186425209045, + 0.2690921127796173, + -0.020104877650737762, + 0.024988379329442978, + 0.15326620638370514, + 0.1256464123725891, + -0.40941280126571655, + -0.946648120880127, + -0.8358487486839294, + -0.14284957945346832, + -0.07980851829051971, + -0.1435413807630539, + 0.038134895265102386, + 0.8021518588066101, + 1.552701473236084, + 1.2496209144592285, + 0.38152581453323364, + 0.07136060297489166, + 0.14329172670841217, + -0.06546801328659058, + -0.5923707485198975, + -1.253793478012085, + -0.9458200335502625, + -0.156633198261261, + -0.04217473417520523, + -0.11199303716421127, + -0.07520301640033722, + 0.15331010520458221, + 0.4794600307941437, + 0.2449675053358078, + -0.10396319627761841, + 0.0034801275469362736, + 0.04475663974881172, + 0.024035215377807617, + 0.056806568056344986, + -0.07363307476043701, + -0.001563104335218668, + 0.05157755687832832, + 0.043718185275793076, + 0.02102719619870186, + 0.11859089881181717, + 0.08675580471754074, + -0.13180124759674072, + -0.15522590279579163, + 0.03273458778858185, + -0.0019622649997472763, + 0.1011638194322586, + -0.10800585150718689, + -0.6884365677833557, + -0.5495791435241699, + 0.0780424103140831, + 0.33674973249435425, + -0.21274283528327942, + -0.4183696210384369, + -0.8053947687149048, + 0.03347628563642502, + 1.3938312530517578, + 0.9454176425933838, + -0.012210174463689327, + 0.04924672842025757, + 0.16284359991550446, + 1.1340152025222778, + 2.0020322799682617, + 0.2796843647956848, + -0.968036413192749, + -0.5768532752990723, + 0.17757350206375122, + 0.37485063076019287, + 0.11534234136343002, + -1.2916942834854126, + -1.692176103591919, + -0.30523377656936646, + 0.14307916164398193, + 0.03928302228450775, + -0.19196964800357819, + -0.4533900022506714, + -0.3294944167137146, + 0.5480389595031738, + 0.4497548043727875, + 0.2170887440443039, + -0.05817069113254547, + -0.06957870721817017, + 0.03169052675366402, + 0.23751793801784515, + 0.0823391005396843, + -0.04811413958668709, + -0.051265716552734375, + -0.0395645909011364, + -0.03849785774946213, + 0.04607917368412018, + 0.09946659207344055, + -0.029992828145623207, + -0.05369366332888603, + -0.005230880342423916, + 0.012808755040168762, + 0.1821947544813156, + 0.05478882044553757, + -0.47736144065856934, + -0.44480830430984497, + -0.036321353167295456, + 0.13646431267261505, + -0.04045571759343147, + -0.21837295591831207, + -0.6888197660446167, + -0.08431777358055115, + 0.96018385887146, + 0.6788493990898132, + 0.011028020642697811, + 0.05917810648679733, + 0.02488739602267742, + 0.6898419857025146, + 1.4259209632873535, + 0.13193827867507935, + -0.8078985810279846, + -0.31056249141693115, + 0.018122224137187004, + 0.137860506772995, + 0.051947757601737976, + -0.9757952094078064, + -1.1060559749603271, + 0.06675099581480026, + 0.2091575562953949, + -0.029623042792081833, + -0.0705878809094429, + -0.18514159321784973, + -0.07947035878896713, + 0.5719470381736755, + 0.2286168485879898, + -0.03433626517653465, + 0.0036030709743499756, + 0.006251791957765818, + 0.04144154116511345, + 0.08598234504461288, + -0.050599172711372375, + -0.10440917313098907, + -0.02927244082093239, + -0.04102599248290062, + -0.07101748138666153, + -0.03579306975007057, + 0.03586365282535553, + 0.06752362847328186, + 0.048901572823524475, + -0.020898710936307907, + -0.009411930106580257, + 0.10169848799705505, + 0.1812015175819397, + -0.014482695609331131, + -0.12548771500587463, + -0.060731250792741776, + -0.034499138593673706, + 0.0829617902636528, + 0.04616715386509895, + -0.20867496728897095, + -0.1990129053592682, + 0.1773940473794937, + 0.13156233727931976, + -0.03437860682606697, + 0.04012921825051308, + -0.11132699251174927, + -0.023460939526557922, + 0.2713286876678467, + -0.06662362813949585, + -0.2709292471408844, + -0.0030232456047087908, + -0.10379529744386673, + -0.07136038690805435, + 0.03757762163877487, + -0.20515622198581696, + -0.1231834888458252, + 0.26915228366851807, + 0.0998353362083435, + -0.031466737389564514, + 0.04657471179962158, + 0.07664929330348969, + 0.10308870673179626, + 0.23429608345031738, + -0.06942534446716309, + -0.09051290899515152, + 0.03243685141205788, + 0.04053235426545143, + -0.021392958238720894, + -0.05330868810415268, + -0.11525140702724457, + -0.03889385238289833, + 0.01636480540037155, + -0.009352890774607658, + 0.13151532411575317, + -0.14738643169403076, + -0.18289834260940552, + 0.15955400466918945, + -0.001023759599775076, + 0.028809679672122, + 0.012261062860488892, + 0.29654747247695923, + -0.285063236951828, + -0.40187928080558777, + 0.3713407516479492, + 0.009383893571794033, + -0.023022817447781563, + -0.003799814498052001, + 0.48470190167427063, + -0.43402406573295593, + -0.5858806371688843, + 0.5751441717147827, + 0.05045031011104584, + -0.05559438094496727, + -0.02045449987053871, + 0.5281224250793457, + -0.5058223605155945, + -0.5950849056243896, + 0.6492323279380798, + 0.013408469036221504, + -0.05940670147538185, + -0.0044364179484546185, + 0.3112560212612152, + -0.34908774495124817, + -0.42427319288253784, + 0.43349501490592957, + 0.03724945709109306, + -0.05263671651482582, + -0.010485195554792881, + 0.1261255145072937, + -0.1349790245294571, + -0.2524855136871338, + 0.24608080089092255, + 0.036001257598400116, + -0.028843939304351807, + 0.0056989979930222034, + 0.04458172619342804, + -0.06122935935854912, + -0.166972354054451, + 0.14557687938213348, + 0.018050044775009155, + 0.032598987221717834, + -0.0055792503990232944, + 0.24355076253414154, + -0.21433626115322113, + -0.29646870493888855, + 0.1958809792995453, + 0.015435033477842808, + 0.05235098674893379, + 0.010786890983581543, + 0.47903597354888916, + -0.4127257168292999, + -0.6203306317329407, + 0.47024452686309814, + 0.0823090448975563, + -0.04538045823574066, + -0.004072466865181923, + 0.7509317994117737, + -0.6508772969245911, + -0.8481631278991699, + 0.7875698208808899, + 0.0966777428984642, + -0.10461349785327911, + 0.0063789174892008305, + 0.7535857558250427, + -0.8082649111747742, + -0.8165622353553772, + 0.9064085483551025, + 0.04986630380153656, + -0.10200339555740356, + 0.0314355194568634, + 0.46324053406715393, + -0.5523763298988342, + -0.5632953643798828, + 0.6378755569458008, + 0.07833302766084671, + -0.07979781180620193, + 0.031164664775133133, + 0.1967470794916153, + -0.21681970357894897, + -0.29283079504966736, + 0.3367702066898346, + 0.034929461777210236, + -0.047199901193380356, + -0.0033645557705312967, + 0.05454660952091217, + -0.11264829337596893, + -0.190998375415802, + 0.17961400747299194, + 0.0009085010970011353, + -0.0001827089727157727, + 0.04841821268200874, + 0.019923821091651917, + -0.07004066556692123, + -0.10590090602636337, + 0.054114967584609985, + 0.04302384704351425, + 0.00462615629658103, + 0.022948985919356346, + 0.1673787385225296, + -0.1319379210472107, + -0.2711219787597656, + 0.2387620061635971, + 0.05667697265744209, + -0.018639734014868736, + -0.07672597467899323, + 0.3503187298774719, + -0.2981504797935486, + -0.38647517561912537, + 0.4072522521018982, + 0.010913677513599396, + -0.05246961489319801, + -0.04058554396033287, + 0.39216771721839905, + -0.3605193495750427, + -0.34857264161109924, + 0.46899959444999695, + -0.03358001261949539, + -0.05188553035259247, + -0.023204902186989784, + 0.17140533030033112, + -0.2120431810617447, + -0.2144550085067749, + 0.2837989032268524, + -0.0191226527094841, + -0.020922169089317322, + 0.004324179142713547, + 0.038136694580316544, + -0.042803723365068436, + -0.11487454175949097, + 0.11820490658283234, + 0.003412557765841484, + 0.0035020115319639444, + 0.03646541014313698, + -0.010104459710419178, + -0.010897459462285042, + -0.09292570501565933, + 0.06823977828025818, + 0.02677192911505699, + 0.020071662962436676, + 0.005776307079941034, + 0.02613351307809353, + 0.017107944935560226, + -0.0002623539185151458, + -0.039298396557569504, + -0.0314190648496151, + -0.019773684442043304, + -0.01924789510667324, + 0.04253160580992699, + 0.09694722294807434, + 0.1925637573003769, + 0.1901547759771347, + 0.09470294415950775, + -0.00296174269169569, + -0.03602522239089012, + 0.03572473302483559, + 0.08787581324577332, + 0.1773553043603897, + 0.20970025658607483, + 0.14899243414402008, + 0.05427362397313118, + -0.032429151237010956, + 0.023915717378258705, + 0.06557436287403107, + 0.13488733768463135, + 0.17550915479660034, + 0.17485061287879944, + 0.10260436683893204, + -0.005381361581385136, + -0.05573735386133194, + -0.09410752356052399, + -0.07940010726451874, + -0.03424998000264168, + 0.007975265383720398, + 0.028827181085944176, + 0.023788832128047943, + -0.02962818741798401, + -0.13474339246749878, + -0.22529757022857666, + -0.20413516461849213, + -0.14711618423461914, + -0.05960607901215553, + 0.04579121991991997, + 0.005325576290488243, + -0.11592217534780502, + -0.2260522097349167, + -0.2467145025730133, + -0.22054187953472137, + -0.13919179141521454, + 0.0016459478065371513, + 0.0515579916536808, + 0.060555730015039444, + 0.040788713842630386, + -0.017907800152897835, + -0.026459651067852974, + -0.02488812990486622, + 0.015644825994968414, + 0.10543125867843628, + 0.19312354922294617, + 0.28380078077316284, + 0.28878358006477356, + 0.16968156397342682, + 0.04848042502999306, + -0.00986899808049202, + 0.06337545067071915, + 0.16356752812862396, + 0.2444516271352768, + 0.29273414611816406, + 0.2314801961183548, + 0.12695762515068054, + -0.022283215075731277, + 0.018402203917503357, + 0.07152476161718369, + 0.14247483015060425, + 0.18759845197200775, + 0.20828258991241455, + 0.14114585518836975, + -0.047197990119457245, + -0.13794781267642975, + -0.17509934306144714, + -0.1696663200855255, + -0.1206701323390007, + -0.036128126084804535, + 0.007180679589509964, + 0.006984225939959288, + -0.09600912779569626, + -0.22975720465183258, + -0.33287662267684937, + -0.2942708134651184, + -0.20305578410625458, + -0.08411446958780289, + 0.042896877974271774, + -0.020053744316101074, + -0.16365791857242584, + -0.3145587742328644, + -0.3321540057659149, + -0.2667454183101654, + -0.1542910486459732, + -0.006954069249331951, + 0.020191870629787445, + 0.014010002836585045, + 0.0016916356980800629, + -0.04649524390697479, + -0.014931428246200085, + -0.017954425886273384, + -0.020003901794552803, + 0.03831968829035759, + 0.08447518199682236, + 0.14068123698234558, + 0.13400419056415558, + 0.08205568045377731, + -0.0004489773709792644, + -0.019211264327168465, + 0.023363608866930008, + 0.08738930523395538, + 0.12299696356058121, + 0.13070489466190338, + 0.09040816128253937, + 0.03286544978618622, + -0.006979941390454769, + -0.0010930931894108653, + 0.04313739389181137, + 0.10121051222085953, + 0.11390950530767441, + 0.11383924633264542, + 0.06694260239601135, + -0.00425445893779397, + -0.0666416585445404, + -0.09225274622440338, + -0.0977785512804985, + -0.07118111103773117, + -0.026749763637781143, + -0.019425569102168083, + 0.03321055322885513, + -0.0033978468272835016, + -0.08309262245893478, + -0.15557922422885895, + -0.14969374239444733, + -0.07188998907804489, + -0.018716221675276756, + 0.022834330797195435, + 0.004232254344969988, + -0.04141783341765404, + -0.125192329287529, + -0.14545302093029022, + -0.12225300818681717, + -0.05844716727733612, + 0.010607236064970493, + 0.024218380451202393, + -0.002702374942600727, + -0.030814893543720245, + 0.03507756441831589, + -0.0506589449942112, + 0.03415676951408386, + 0.0011444400297477841, + 0.0026324463542550802, + 0.028514407575130463, + -0.01849454641342163, + -0.030959082767367363, + -0.05565863475203514, + 0.05771413818001747, + 0.003916156478226185, + -0.004474544432014227, + 0.04403551295399666, + 0.1733711212873459, + -0.37650829553604126, + 0.22322984039783478, + 0.0032540319953113794, + -0.01139416079968214, + -0.039046600461006165, + 0.0021948080975562334, + 0.5777754783630371, + -1.1944804191589355, + 0.769478976726532, + -0.1349843591451645, + 0.0004430754925124347, + -0.0061850035563111305, + -0.08340868353843689, + 0.8327823877334595, + -1.649588942527771, + 1.126111388206482, + -0.2918313145637512, + 0.003614947199821472, + 0.0016799914883449674, + -0.03255167230963707, + 0.6123784184455872, + -1.1993682384490967, + 0.8305437564849854, + -0.13622376322746277, + 0.00905851274728775, + -0.006772476714104414, + 0.07578610628843307, + 0.05859832838177681, + -0.4543764293193817, + 0.26330503821372986, + 0.0259060300886631, + -0.0007997890934348106, + 0.01269856933504343, + 0.006897627376019955, + -0.02491801232099533, + -0.03139931708574295, + 0.0028456314466893673, + 0.0008253560517914593, + -0.01086023822426796, + -0.004186873324215412, + 0.06299160420894623, + -0.039931319653987885, + -0.09315146505832672, + 0.05495935305953026, + 0.027547571808099747, + -0.010900916531682014, + -0.025233760476112366, + 0.060600072145462036, + 0.21010243892669678, + -0.5445898771286011, + 0.35070353746414185, + -0.033771682530641556, + -0.0269146841019392, + -0.025363197550177574, + -0.021729450672864914, + 0.70921790599823, + -1.4368270635604858, + 0.9582043290138245, + -0.1708265244960785, + 0.010022420436143875, + -0.032301150262355804, + -0.08667651563882828, + 1.0338889360427856, + -1.913576364517212, + 1.262008547782898, + -0.23795078694820404, + -0.032233912497758865, + -0.01397701445966959, + -0.05402921140193939, + 0.7621430158615112, + -1.387437343597412, + 0.8621506094932556, + -0.14765247702598572, + -0.004747485741972923, + 0.0017516895895823836, + 0.08154146373271942, + 0.16601374745368958, + -0.5324177742004395, + 0.27442997694015503, + 0.03274058923125267, + -0.008812552317976952, + 0.005774920806288719, + 0.04165825620293617, + -0.011749272234737873, + -0.01953396573662758, + -0.009672109968960285, + 0.01170953270047903, + 0.003071938641369343, + -0.018979815766215324, + 0.062123894691467285, + -0.004921444226056337, + -0.03380037844181061, + 0.01310884952545166, + 0.007953890599310398, + -0.0012086924398317933, + -0.03317898139357567, + -0.0015596294542774558, + 0.08166785538196564, + -0.2291223704814911, + 0.11783571541309357, + -0.016078786924481392, + 0.018957575783133507, + 0.025793947279453278, + -0.09036394208669662, + 0.3833881616592407, + -0.5794023871421814, + 0.4610825777053833, + -0.14165280759334564, + -0.007412370759993792, + 0.05252876877784729, + -0.21435455977916718, + 0.6177686452865601, + -0.8516795635223389, + 0.667263925075531, + -0.22572898864746094, + -0.004465761594474316, + 0.02589319832623005, + -0.1893543303012848, + 0.43213585019111633, + -0.6462821364402771, + 0.434274822473526, + -0.15750259160995483, + -0.01198036689311266, + -2.4281514924950898e-05, + 0.039562296122312546, + 0.11126027256250381, + -0.23193514347076416, + 0.1412443071603775, + -0.011839920654892921, + 0.007880321703851223, + 0.02950354479253292, + 0.011689653620123863, + -0.07272310554981232, + -0.03319466486573219, + -0.003948990721255541, + 0.03549842908978462, + -0.02165558747947216, + -0.09912239760160446, + -0.08742356300354004, + 0.30591821670532227, + 0.23934677243232727, + 0.02658180706202984, + -0.022127188742160797, + -0.02769642136991024, + 0.16399237513542175, + 0.5140998959541321, + 0.007951628416776657, + -0.5589093565940857, + -0.24106110632419586, + -0.02753414213657379, + 0.06947467476129532, + 0.048558495938777924, + -0.5370690822601318, + -0.761831521987915, + 0.16272802650928497, + 0.29426246881484985, + 0.07943751662969589, + -0.022394873201847076, + -0.217612162232399, + -0.03093647211790085, + 0.5945476293563843, + 0.2873935103416443, + -0.16481661796569824, + -0.02931203693151474, + -0.029083512723445892, + 0.06754925847053528, + 0.20200076699256897, + -0.07271742075681686, + -0.1976277083158493, + -0.04189611226320267, + 0.06403793394565582, + -0.00022445111244451255, + -0.01032529678195715, + -0.03415631130337715, + 0.009091783314943314, + 0.04317992925643921, + 0.07196266949176788, + -0.025028688833117485, + -0.02722775563597679, + -0.017168480902910233, + -0.027666645124554634, + -0.06734028458595276, + 0.10843724757432938, + 0.08066407591104507, + -0.027849983423948288, + -0.0045820740051567554, + -0.03388727456331253, + 0.16772156953811646, + 0.651636004447937, + 0.34874194860458374, + -0.1454945057630539, + -0.18056720495224, + 0.11703842133283615, + 0.43017855286598206, + 0.7624525427818298, + -0.3420296907424927, + -1.272199273109436, + -0.5284644365310669, + -0.005667245015501976, + 0.08240436762571335, + -0.13299596309661865, + -1.3164156675338745, + -1.659982442855835, + 0.19898656010627747, + 0.6253566741943359, + 0.25137946009635925, + -0.18244975805282593, + -0.5360167622566223, + -0.06195700913667679, + 1.2547520399093628, + 1.0296341180801392, + 0.10651036351919174, + -0.023540280759334564, + -0.07594245672225952, + 0.1492130160331726, + 0.5033117532730103, + 0.09394379705190659, + -0.22459803521633148, + -0.22473134100437164, + -0.04738321527838707, + 0.04127531498670578, + 0.0682951882481575, + -0.02095615118741989, + -0.1233135387301445, + -0.10028401762247086, + -0.008111395873129368, + -0.000617706507910043, + 0.018859047442674637, + 0.028446361422538757, + -0.06159031391143799, + -0.1292838156223297, + 0.051308393478393555, + 0.11001072078943253, + -0.02056661807000637, + -0.012175443582236767, + -0.1313694268465042, + 0.0067574759013950825, + 0.4612729251384735, + 0.323080450296402, + -0.09392253309488297, + -0.1256203055381775, + 0.03537299111485481, + 0.2556088864803314, + 0.6467183232307434, + -0.16340143978595734, + -0.8799455165863037, + -0.3312987685203552, + 0.01464154850691557, + 0.07046713680028915, + 0.053634822368621826, + -0.8514915108680725, + -1.176972508430481, + 0.2056443840265274, + 0.4998764395713806, + 0.1268644779920578, + -0.10905193537473679, + -0.3750888705253601, + -0.06701061874628067, + 0.9052186608314514, + 0.6792045831680298, + -0.00323892361484468, + -0.0007412935374304652, + -0.03608793020248413, + 0.1009129211306572, + 0.36775916814804077, + 0.035214491188526154, + -0.2273784875869751, + -0.15815992653369904, + -0.004773923195898533, + 0.06374036520719528, + 0.04737555980682373, + -0.0563247986137867, + -0.09587392956018448, + -0.043853096663951874, + 0.032572731375694275, + -0.0036250585690140724, + 0.07889056205749512, + -0.03589344769716263, + -0.019771328195929527, + 0.04937156289815903, + 0.039052557200193405, + -0.013377528637647629, + -0.0841481015086174, + -0.03358105197548866, + -0.2128981053829193, + -0.14468812942504883, + 0.14675867557525635, + 0.2550889551639557, + 0.22369499504566193, + -0.0032973098568618298, + 0.006679064594209194, + -0.11752036958932877, + 0.025247232988476753, + 0.23064176738262177, + 0.25043538212776184, + 0.3474777638912201, + 0.2151806503534317, + 0.051294319331645966, + 0.16301114857196808, + 0.25422143936157227, + -0.1796918362379074, + -0.6128425598144531, + -0.42049655318260193, + 0.07740531116724014, + -0.007960617542266846, + 0.2504507601261139, + 0.2932300865650177, + -0.5157915949821472, + -1.2904177904129028, + -1.0362532138824463, + -0.22443994879722595, + 0.007411653641611338, + 0.16024430096149445, + 0.33939966559410095, + -0.2748318016529083, + -0.8487470149993896, + -0.5955387949943542, + 0.033155132085084915, + -0.09185351431369781, + -0.05639262869954109, + 0.17084303498268127, + 0.11292264610528946, + -0.046329669654369354, + 0.11495561897754669, + 0.31740760803222656, + -0.13903948664665222, + 0.05507560819387436, + 0.10180198401212692, + -0.1369788944721222, + -0.10618618875741959, + -0.001083499751985073, + 0.16340164840221405, + 0.07591762393712997, + 0.3417445123195648, + 0.27897438406944275, + -0.32192930579185486, + -0.5731648206710815, + -0.46150147914886475, + -0.03230089321732521, + 0.04096771031618118, + 0.22242987155914307, + 0.027000218629837036, + -0.4113498628139496, + -0.433158278465271, + -0.5252256393432617, + -0.3510502874851227, + -0.133863165974617, + -0.38554033637046814, + -0.45547229051589966, + 0.2475612610578537, + 1.154951572418213, + 0.8282179236412048, + -0.13197137415409088, + -0.03350961208343506, + -0.5282800197601318, + -0.5297923684120178, + 0.9037952423095703, + 2.516275405883789, + 2.086421489715576, + 0.3573826849460602, + -0.010694397613406181, + -0.31418153643608093, + -0.5325371026992798, + 0.48083701729774475, + 1.7732245922088623, + 1.2747145891189575, + -0.06401863694190979, + 0.14296381175518036, + 0.07267159968614578, + -0.28001847863197327, + -0.29204103350639343, + 0.12853951752185822, + -0.1998838633298874, + -0.6375644207000732, + 0.06310836225748062, + -0.020014479756355286, + -0.08150970935821533, + 0.08175478130578995, + 0.07667485624551773, + 0.0025236753281205893, + -0.08504530042409897, + -0.035742271691560745, + -0.1332666128873825, + -0.15150736272335052, + 0.18459312617778778, + 0.3363596200942993, + 0.2501969635486603, + 0.029292423278093338, + -0.060296736657619476, + -0.1142202764749527, + -0.05918247997760773, + 0.18826954066753387, + 0.2183520495891571, + 0.21247169375419617, + 0.14935970306396484, + 0.09923429787158966, + 0.21808095276355743, + 0.21930061280727386, + -0.060535889118909836, + -0.5729222297668457, + -0.4199080169200897, + 0.058897778391838074, + 0.050647757947444916, + 0.2784770131111145, + 0.2754706144332886, + -0.40136128664016724, + -1.3269731998443604, + -1.124815583229065, + -0.11878778040409088, + -0.005137663800269365, + 0.17839783430099487, + 0.2115524858236313, + -0.24165289103984833, + -0.9655010104179382, + -0.7425088286399841, + 0.0304054357111454, + -0.07012742757797241, + -0.015557953156530857, + 0.1128007024526596, + 0.18957749009132385, + -0.07996463775634766, + 0.09505810588598251, + 0.34419506788253784, + -0.3072076439857483, + 0.03868290036916733, + 0.11494885385036469, + 0.03748936951160431, + 0.0797261893749237, + -0.003397951368242502, + -0.07380004972219467, + -0.11507676541805267, + -0.10298885405063629, + 0.10698320716619492, + 0.06602972000837326, + 0.08226803690195084, + 0.0037747276946902275, + -0.162277951836586, + 0.01671667955815792, + 0.09137773513793945, + 0.18799471855163574, + 0.04144813120365143, + 0.1285877376794815, + 0.1820434182882309, + 0.04940629005432129, + 0.0991915687918663, + 0.10219171643257141, + -0.013141660951077938, + -0.051191627979278564, + 0.05468929558992386, + 0.087598517537117, + 0.15897324681282043, + 0.11863455921411514, + -0.00814050156623125, + -0.07701541483402252, + -0.14013728499412537, + -0.044140227138996124, + -0.05328791216015816, + 0.06760499626398087, + 0.12053386867046356, + 0.09780212491750717, + -0.053725965321063995, + -0.07915244251489639, + -0.0032519602682441473, + 0.019637396559119225, + 0.07848430424928665, + 0.019138827919960022, + 0.1460287868976593, + 0.1281038075685501, + 0.024417784065008163, + 0.059176862239837646, + 0.0658111497759819, + -0.016405148431658745, + -0.18877744674682617, + 0.16666102409362793, + 0.1610611230134964, + 0.08374520391225815, + 0.11570518463850021, + 0.11903064697980881, + 0.1294964700937271, + 0.06379758566617966, + 0.08417274057865143, + 0.12754113972187042, + 0.025328608229756355, + 0.05170705169439316, + 0.0835295170545578, + 0.07477264851331711, + 0.11244285851716995, + 0.11559426784515381, + 0.045258160680532455, + -0.14825093746185303, + -0.08153342455625534, + 0.06288623809814453, + 0.11952362209558487, + 0.11784297972917557, + 0.011141132563352585, + -0.21666541695594788, + -0.29976174235343933, + -0.2279169261455536, + -0.11828474700450897, + 0.12436322867870331, + 0.10465826094150543, + -0.09751085937023163, + -0.292611300945282, + -0.37374064326286316, + -0.31437963247299194, + -0.25637903809547424, + 0.06173908710479736, + 0.14131486415863037, + 0.008434675633907318, + -0.23816508054733276, + -0.30330890417099, + -0.22094152867794037, + -0.11608295142650604, + 0.13235151767730713, + 0.15353602170944214, + 0.15839524567127228, + 0.012247815728187561, + -0.08126968890428543, + -0.003756331978365779, + 0.10660683363676071, + 0.21976575255393982, + -0.04188326746225357, + 0.15462253987789154, + 0.06303395330905914, + 0.006879634689539671, + 0.008284888230264187, + 0.07084798067808151, + 0.1211942657828331, + 0.10190404951572418, + 0.02935362420976162, + -0.05645999684929848, + -0.16800500452518463, + -0.1850246787071228, + -0.09476880729198456, + -0.025327544659376144, + 0.054355036467313766, + -0.035813912749290466, + -0.18694879114627838, + -0.34871891140937805, + -0.3151862621307373, + -0.1943007856607437, + -0.09755205363035202, + 0.014881589449942112, + -0.14875493943691254, + -0.37112873792648315, + -0.37739917635917664, + -0.3241480886936188, + -0.2915399968624115, + -0.11268249899148941, + -0.019726404920220375, + -0.2510305941104889, + -0.38005372881889343, + -0.3622463345527649, + -0.2932804226875305, + -0.28574010729789734, + -0.1505027860403061, + -0.004947682376950979, + -0.18587322533130646, + -0.34759166836738586, + -0.28965193033218384, + -0.21052972972393036, + -0.18780536949634552, + -0.07400713860988617, + 0.11154936999082565, + -0.03556853160262108, + -0.1896934062242508, + -0.18135806918144226, + -0.10117948800325394, + -0.0393117293715477, + 0.06517928093671799, + -0.016659021377563477, + -0.011290309950709343, + -0.007930322550237179, + 0.008189777843654156, + 0.03678786754608154, + 0.021890517324209213, + 0.0034292477648705244, + 0.02200375869870186, + 0.0014921070542186499, + -0.0800287202000618, + -0.17657361924648285, + -0.18702608346939087, + -0.12880444526672363, + -0.022084584459662437, + 0.026420501992106438, + -0.023968446999788284, + -0.07948111742734909, + -0.16741475462913513, + -0.18733707070350647, + -0.16539834439754486, + -0.07347387820482254, + -0.009723886847496033, + -0.02016977220773697, + -0.061092622578144073, + -0.13145211338996887, + -0.15919029712677002, + -0.15043555200099945, + -0.10107766091823578, + 0.0016151965828612447, + 0.0627974420785904, + 0.08695066720247269, + 0.11727584898471832, + 0.11745581030845642, + 0.11329426616430283, + 0.0533670075237751, + -0.016355818137526512, + 0.008450252935290337, + 0.06448577344417572, + 0.1538505256175995, + 0.21232697367668152, + 0.14713847637176514, + 0.039088234305381775, + -0.015588105656206608, + 0.026483291760087013, + 0.060862988233566284, + 0.18265819549560547, + 0.23042462766170502, + 0.168768972158432, + 0.034099943935871124, + -0.018249109387397766, + -0.0321880541741848, + -0.03254542127251625, + -0.03061222843825817, + -0.0026304698549211025, + 0.017764942720532417, + 0.010707704350352287, + 0.009254949167370796, + -0.04533161595463753, + -0.1483704000711441, + -0.2637183666229248, + -0.2678598165512085, + -0.1737881749868393, + -0.049990858882665634, + 0.013515918515622616, + -0.054345693439245224, + -0.1467861533164978, + -0.24911582469940186, + -0.2831358015537262, + -0.22300836443901062, + -0.13739243149757385, + -0.017879672348499298, + -0.040345460176467896, + -0.09990613907575607, + -0.16936856508255005, + -0.2266550064086914, + -0.2020808756351471, + -0.1509508341550827, + 0.014163740910589695, + 0.07591170817613602, + 0.09185601025819778, + 0.10455341637134552, + 0.09514842182397842, + 0.09877350926399231, + 0.053898438811302185, + 0.005704578943550587, + 0.0591997392475605, + 0.13600079715251923, + 0.21777905523777008, + 0.2574957311153412, + 0.20117221772670746, + 0.11415109038352966, + -0.001181072206236422, + 0.09470006823539734, + 0.18978413939476013, + 0.3073742389678955, + 0.36875811219215393, + 0.3069853186607361, + 0.1708926260471344, + -0.0325310118496418, + -0.02656698040664196, + 0.016060845926404, + 0.02459372952580452, + 0.04165660962462425, + 0.033969976007938385, + 0.012855498120188713, + 0.030497560277581215, + 0.004896117839962244, + -0.030887477099895477, + -0.13454437255859375, + -0.1294785887002945, + -0.06398608535528183, + 0.016156472265720367, + 0.03577340394258499, + -0.0033482143189758062, + -0.07112833857536316, + -0.16465041041374207, + -0.1621057391166687, + -0.09478478878736496, + -0.03555302321910858, + -0.001592929707840085, + -0.01719600521028042, + -0.06598587334156036, + -0.1411861628293991, + -0.1496778130531311, + -0.11535074561834335, + -0.0905962884426117, + -0.013807609677314758, + 0.029542237520217896, + 0.039138730615377426, + 0.03988270089030266, + 0.02665030211210251, + 0.049553126096725464, + -0.0015685928519815207, + -0.018007200211286545, + 0.009533192962408066, + 0.06910547614097595, + 0.1034330427646637, + 0.15017645061016083, + 0.10221225768327713, + 0.020978443324565887, + -0.023747621104121208, + 0.02295384369790554, + 0.09313814342021942, + 0.1771395057439804, + 0.21169933676719666, + 0.17989481985569, + 0.05862005427479744, + -0.004540165886282921, + 0.021994179114699364, + -0.003493826137855649, + -0.000224211675231345, + 0.031808022409677505, + -0.05090906098484993, + 0.001970196608453989, + 0.01633802428841591, + 0.0049764602445065975, + 0.0006027702474966645, + -0.005952450912445784, + -0.009886081330478191, + -0.08520589768886566, + 0.030780712142586708, + 0.00037104589864611626, + 0.011886775493621826, + -0.023506291210651398, + 0.08029806613922119, + -0.005086984951049089, + -0.07738454639911652, + 0.06721897423267365, + -0.02397127076983452, + 0.006669329944998026, + -0.016343094408512115, + 0.06056324020028114, + 0.15656796097755432, + -0.49836501479148865, + 0.2475810945034027, + -0.009270203299820423, + -0.006855266634374857, + 0.0034896093420684338, + -0.027938276529312134, + 0.5722692012786865, + -1.1357109546661377, + 0.5644665956497192, + 0.015787361189723015, + -0.015141892246901989, + -0.0032788251992315054, + -0.04797150194644928, + 0.6196744441986084, + -1.1540743112564087, + 0.6065864562988281, + 0.0019708566833287477, + 0.006332532037049532, + 0.014192940667271614, + 0.03773411735892296, + 0.27323007583618164, + -0.594700813293457, + 0.2488076239824295, + -0.008853388018906116, + 0.005692378617823124, + 0.000576167949475348, + -0.027197014540433884, + 0.022015029564499855, + -0.02571249194443226, + 0.004507753532379866, + -0.002439734758809209, + -0.01994609646499157, + 0.03601142391562462, + 0.008136607706546783, + 0.01658148691058159, + -0.06548810750246048, + 0.022721221670508385, + -0.0038820707704871893, + -0.0007800398161634803, + 0.001392301986925304, + 0.09576108306646347, + -0.014628835022449493, + -0.14505760371685028, + 0.07135403156280518, + -0.00839388556778431, + -0.004555124789476395, + -0.04466082155704498, + 0.1456393599510193, + 0.3475525975227356, + -0.7879117131233215, + 0.36262738704681396, + 0.008226356469094753, + 0.0055343699641525745, + -0.061139706522226334, + 0.08975803852081299, + 0.9340736269950867, + -1.7307822704315186, + 0.796896755695343, + -0.024700213223695755, + -0.013090251013636589, + -0.05148586630821228, + 0.050525497645139694, + 0.927090048789978, + -1.7473385334014893, + 0.7727715373039246, + -0.005721901543438435, + 0.010676853358745575, + -0.012798544019460678, + 0.11131046712398529, + 0.4181194007396698, + -0.8475598096847534, + 0.33206430077552795, + 0.018843427300453186, + 0.0006885005859658122, + 0.027498219162225723, + 0.00207257061265409, + 0.0032615051604807377, + -0.021950624883174896, + -0.008452882058918476, + -0.007631891872733831, + -0.028561849147081375, + 0.04865337535738945, + -0.0023105579894036055, + -0.026170270517468452, + -0.011794357560575008, + 0.004327487666159868, + 0.01756221242249012, + 0.0011611212976276875, + -0.008793564513325691, + 0.0741758644580841, + -0.057649385184049606, + -0.006000686902552843, + -0.022717488929629326, + -0.0047143916599452496, + 0.005709030199795961, + -0.05611564591526985, + 0.05792170390486717, + 0.1873699128627777, + -0.3856293857097626, + 0.1371920108795166, + 0.018953431397676468, + 0.015250314958393574, + -0.0016827551880851388, + -0.08515634387731552, + 0.6517581939697266, + -0.9557326436042786, + 0.46986615657806396, + -0.014306572265923023, + -0.01625121757388115, + -0.016088897362351418, + -0.13429272174835205, + 0.6437729001045227, + -1.0167845487594604, + 0.5061463117599487, + 0.00879831612110138, + -0.008598369546234608, + 0.02747279778122902, + 0.007245234213769436, + 0.2527446150779724, + -0.47163763642311096, + 0.15560215711593628, + 0.005050336476415396, + -0.024848125874996185, + -0.0006449198699556291, + -0.008673148229718208, + -0.06940636038780212, + -0.016248086467385292, + 0.1250494420528412, + 0.026387182995676994, + 0.009615709073841572, + -0.0025482974015176296, + -0.04534498229622841, + -0.2626228630542755, + -0.2753732204437256, + 0.052055053412914276, + 0.010792221873998642, + 0.007360508665442467, + 0.10271529853343964, + 0.1113760769367218, + -0.31120774149894714, + -0.49849262833595276, + -0.2206398844718933, + 0.04994913563132286, + 0.054614756256341934, + 0.27786919474601746, + 0.56647789478302, + 0.20970205962657928, + -0.22717078030109406, + -0.17321231961250305, + -0.07836200296878815, + -0.09607961028814316, + 0.10685958713293076, + 0.40848156809806824, + 0.34087467193603516, + -0.005242985673248768, + -0.0682876780629158, + -0.0694413110613823, + -0.1886596381664276, + -0.04473332315683365, + 0.18096435070037842, + 0.1961163580417633, + 0.0014336564345285296, + 0.014584851451218128, + 0.0462430939078331, + -0.1556192934513092, + -0.12809665501117706, + 0.0213937908411026, + 0.10984069108963013, + -0.023050926625728607, + -0.013447473756968975, + 0.007857509888708591, + -0.027979737147688866, + -0.04768490046262741, + -0.09350565075874329, + -0.1659490317106247, + 0.007927919737994671, + 0.26641780138015747, + 0.03398526459932327, + 0.02118881419301033, + -0.006898822728544474, + -0.15209096670150757, + -0.4939330220222473, + -0.42655149102211, + 0.08215854316949844, + 0.02115131914615631, + 0.08892140537500381, + 0.2164168655872345, + 0.12431265413761139, + -0.47813764214515686, + -0.6588870882987976, + -0.3097454905509949, + 0.0837375745177269, + 0.1548176258802414, + 0.49661529064178467, + 0.7337944507598877, + 0.1966201215982437, + -0.29367199540138245, + -0.2547970116138458, + -0.11655519157648087, + -0.11720486730337143, + 0.21941716969013214, + 0.5902130603790283, + 0.42572125792503357, + 0.020460324361920357, + -0.12768393754959106, + -0.12030418962240219, + -0.2582310736179352, + -0.0355166494846344, + 0.2766987085342407, + 0.28080257773399353, + 0.08665957301855087, + 0.027141664177179337, + 0.02690703421831131, + -0.25276950001716614, + -0.23180679976940155, + 0.015180152840912342, + 0.11523276567459106, + 0.041165824979543686, + 0.017444534227252007, + 0.0009439520072191954, + -0.025763530284166336, + -0.022880665957927704, + -0.024819007143378258, + -0.04901815578341484, + 0.027672944590449333, + 0.11211585998535156, + 0.024664992466568947, + -0.010093869641423225, + 0.009466213174164295, + -0.043605536222457886, + -0.17007218301296234, + -0.1366996467113495, + 0.08740171790122986, + -0.014591479673981667, + -0.0031720874831080437, + 0.0835830345749855, + 0.028662094846367836, + -0.21436777710914612, + -0.24753160774707794, + -0.06092096120119095, + 0.03788171336054802, + 0.04295210912823677, + 0.19064708054065704, + 0.3095722496509552, + 0.08003447204828262, + -0.09509303420782089, + -0.05495578795671463, + -0.052218906581401825, + -0.07204427570104599, + 0.07710819691419601, + 0.18033725023269653, + 0.0834946483373642, + -0.049662720412015915, + -0.06561554968357086, + -0.013351643458008766, + -0.11217659711837769, + 0.031957074999809265, + 0.12180440872907639, + 0.06891122460365295, + -0.013705568388104439, + 0.0011150656500831246, + 0.03281388059258461, + -0.11285661906003952, + -0.06422404199838638, + 0.04218210279941559, + 0.014165353029966354, + -0.006244795396924019, + 0.01745765097439289, + 0.08924975246191025, + 0.01710040494799614, + -0.14013372361660004, + -0.21913501620292664, + 0.03613810986280441, + 0.14273521304130554, + 0.05801931768655777, + 0.021427493542432785, + 0.23185034096240997, + 0.2427377849817276, + -0.4384608566761017, + -0.7205182909965515, + -0.18313364684581757, + 0.033575087785720825, + -0.0809125304222107, + 0.04173902049660683, + 0.7251381874084473, + 1.1058244705200195, + -0.015065462328493595, + -0.6434917449951172, + -0.3080260753631592, + -0.090518057346344, + -0.3659006655216217, + -0.4520319700241089, + 0.5924424529075623, + 1.4148176908493042, + 0.5285682082176208, + -0.027211233973503113, + -0.07359065115451813, + -0.08583711832761765, + -0.5631492137908936, + -1.0246236324310303, + -0.1835726648569107, + 0.3307121694087982, + 0.22562064230442047, + 0.05237145721912384, + 0.13263091444969177, + 0.13899636268615723, + -0.1626550555229187, + -0.3918432295322418, + -0.03585565462708473, + 0.06904798001050949, + 0.029870154336094856, + 0.04289601743221283, + 0.05758490040898323, + 0.10055387020111084, + -0.011962685734033585, + -0.13269846141338348, + 0.0012237781193107367, + 0.05511128902435303, + 0.03764793649315834, + -0.07580426335334778, + -0.1750984787940979, + 0.0189101230353117, + 0.08156414330005646, + 0.01691802591085434, + 0.004023027140647173, + 0.18009696900844574, + 0.22744491696357727, + -0.38747039437294006, + -0.6413040161132812, + -0.19208981096744537, + 0.01971367374062538, + -0.036756888031959534, + 0.004946697968989611, + 0.7331712245941162, + 1.1178003549575806, + 0.03220612183213234, + -0.5881579518318176, + -0.24453559517860413, + -0.11856977641582489, + -0.43593257665634155, + -0.5339378118515015, + 0.49467018246650696, + 1.3376370668411255, + 0.5238692164421082, + 0.04584280773997307, + 0.004761924035847187, + -0.032823480665683746, + -0.5419207811355591, + -1.0093209743499756, + -0.19847697019577026, + 0.20687319338321686, + 0.12301573902368546, + 0.07981085777282715, + 0.14125365018844604, + 0.19885297119617462, + -0.1678825318813324, + -0.4042292535305023, + 0.004483209457248449, + 0.03009556047618389, + 0.010802071541547775, + 0.005967534612864256, + 0.0892769992351532, + 0.07342032343149185, + -0.0588892325758934, + -0.09044717997312546, + 0.06307072192430496, + -0.012583961710333824, + -0.006880680099129677, + 0.0030021765269339085, + 0.01633061282336712, + 0.06990820169448853, + 0.0070900083519518375, + -0.03546716272830963, + -0.022131899371743202, + -0.02906683459877968, + 0.010664403438568115, + -0.18731924891471863, + -0.158770352602005, + 0.08571326732635498, + 0.039154618978500366, + 0.032578419893980026, + -0.005781106185168028, + 0.17460086941719055, + 0.2787456810474396, + -0.13416190445423126, + -0.23966801166534424, + -0.004878139588981867, + 0.02796499989926815, + -0.06610933691263199, + -0.19162042438983917, + 0.11163146048784256, + 0.371842622756958, + 0.06444671750068665, + 0.016595548018813133, + 0.01164282951503992, + 0.08330011367797852, + -0.03192862868309021, + -0.2867860198020935, + -0.07080501317977905, + -0.016348646953701973, + -0.06306261569261551, + -0.016291450709104538, + 0.010558445006608963, + 0.13014638423919678, + 0.06202690303325653, + -0.03361419215798378, + 0.0691375732421875, + 0.003561250865459442, + -0.013095442205667496, + -0.050333790481090546, + -0.019117066636681557, + 0.0012089330703020096, + -0.004555183462798595, + -0.022682132199406624, + 0.04747068136930466, + -0.06425238400697708, + -0.0010437731398269534, + -0.0071629988960921764, + -0.04302623122930527, + -0.04830477759242058, + -0.04069536179304123, + -0.06627446413040161, + -0.011470981873571873, + 0.03961857780814171, + 0.026594260707497597, + -0.020662540569901466, + -0.05999285355210304, + -0.053548794239759445, + -0.025959201157093048, + -0.015834785997867584, + 0.013910192996263504, + -0.015868371352553368, + -0.056620921939611435, + -0.06785159558057785, + -0.061030179262161255, + -0.03560228645801544, + -0.04177624359726906, + -0.024657463654875755, + -0.04889696091413498, + 0.004557035397738218, + 0.15414470434188843, + 0.21642963588237762, + 0.035425592213869095, + -0.04339970648288727, + -0.05034525692462921, + -0.08522290736436844, + 0.10652441531419754, + 0.6791198253631592, + 0.7785530686378479, + 0.19941796362400055, + -0.05430706962943077, + -0.02583213709294796, + -0.055139996111392975, + 0.17940561473369598, + 0.6757862567901611, + 0.8240399360656738, + 0.25826773047447205, + -0.062254682183265686, + -0.026456547901034355, + -0.027271386235952377, + -0.0026193747762590647, + 0.11893659085035324, + 0.1915995329618454, + 0.013776157051324844, + 0.08452087640762329, + 0.009950258769094944, + 0.01774573139846325, + 0.06609759479761124, + 0.06512798368930817, + 0.07601971179246902, + 0.09192144125699997, + 0.007696932647377253, + -0.056120894849300385, + -0.03937293961644173, + 0.043086692690849304, + 0.055803027004003525, + 0.08208976686000824, + 0.03658852353692055, + 0.025779196992516518, + -0.0340605266392231, + 0.03186321631073952, + 0.09720855951309204, + 0.10651290416717529, + 0.09562067687511444, + 0.08120692521333694, + 0.06832587718963623, + 0.03940538689494133, + 0.09561086446046829, + -0.03726261481642723, + -0.3520663380622864, + -0.4187469184398651, + -0.11643502116203308, + 0.06203937157988548, + 0.056670401245355606, + 0.11540547758340836, + -0.2742924690246582, + -1.1301417350769043, + -1.2482489347457886, + -0.4411431849002838, + 0.08538330346345901, + 0.036888301372528076, + 0.08759869635105133, + -0.32129940390586853, + -1.1163593530654907, + -1.26430082321167, + -0.48638999462127686, + 0.1056363582611084, + 0.042436979711055756, + 0.07075526565313339, + -0.08341801166534424, + -0.30567145347595215, + -0.39268070459365845, + -0.10187282413244247, + -0.02507772110402584, + -0.0044433241710066795, + -0.009278317913413048, + -0.02964872494339943, + -0.018799586221575737, + -0.03760084509849548, + -0.030454028397798538, + -0.004638439975678921, + 0.026587119325995445, + 0.0095819728448987, + -0.007110759150236845, + -0.006491640582680702, + -0.028083719313144684, + -0.009543413296341896, + -0.005706887226551771, + 0.013012710027396679, + -0.010281933471560478, + -0.0544208325445652, + -0.023230208083987236, + -0.05344587564468384, + -0.04052828997373581, + -0.028035156428813934, + -0.011922319419682026, + -0.045427750796079636, + 0.020700184628367424, + 0.2117788940668106, + 0.21090814471244812, + 0.07214333862066269, + -0.019348343834280968, + -0.014455118216574192, + -0.03561105206608772, + 0.17339389026165009, + 0.49509289860725403, + 0.5219546556472778, + 0.26121678948402405, + -0.029803339391946793, + -0.013761913403868675, + -0.04028521850705147, + 0.17008572816848755, + 0.45583003759384155, + 0.4757367670536041, + 0.22357690334320068, + -0.050064269453287125, + -0.021086007356643677, + -0.039873600006103516, + 0.06433176249265671, + 0.20187893509864807, + 0.2078690379858017, + 0.07802058011293411, + 0.022050827741622925, + -0.05272649601101875, + -0.024311071261763573, + -0.12387345731258392, + -0.20065246522426605, + 0.0262442696839571, + 0.20101603865623474, + 0.056791841983795166, + -0.008266052231192589, + 0.025132112205028534, + -0.23289933800697327, + -0.5296569466590881, + -0.282010018825531, + 0.025113720446825027, + 0.13172000646591187, + 0.16999290883541107, + 0.31588253378868103, + 0.05583454668521881, + -0.5321000814437866, + -0.5585035085678101, + -0.23885560035705566, + 0.0461968369781971, + 0.13807418942451477, + 0.6536149382591248, + 0.6385176777839661, + -0.15636183321475983, + -0.5484278798103333, + -0.5470613241195679, + -0.06269911676645279, + -0.06726553291082382, + 0.5561463236808777, + 1.0985187292099, + 0.6801460385322571, + 0.12841203808784485, + -0.21693651378154755, + -0.19168342649936676, + -0.43073776364326477, + -0.15226863324642181, + 0.41150590777397156, + 0.47421786189079285, + 0.25146934390068054, + -0.017203813418745995, + -0.09694849699735641, + -0.4082376956939697, + -0.3549531400203705, + -0.023591510951519012, + 0.12086013704538345, + 0.08050766587257385, + -0.044960521161556244, + -0.0031193571630865335, + 0.014398006722331047, + -0.005931032355874777, + 0.01548685971647501, + 0.05407734215259552, + -0.006386967841535807, + 0.021660227328538895, + 0.01656133122742176, + 0.002835798542946577, + 0.0008500503608956933, + 0.021802745759487152, + 0.13470955193042755, + 0.06802596151828766, + 0.0033256933093070984, + 0.03037848509848118, + 0.054654810577631, + -0.034221138805150986, + 0.015171438455581665, + 0.23395732045173645, + 0.24771827459335327, + 0.16352902352809906, + -0.07505007833242416, + -0.0814652070403099, + -0.21493901312351227, + -0.3109704852104187, + 0.013416547328233719, + 0.12807825207710266, + 0.12044191360473633, + -0.007915153168141842, + 0.0100772799924016, + -0.15165796875953674, + -0.4013277292251587, + -0.24811144173145294, + -0.06641282886266708, + 0.022568246349692345, + 0.061083581298589706, + 0.09920243173837662, + 0.0695505365729332, + -0.12213064730167389, + -0.12606006860733032, + -0.04593949392437935, + -0.040190644562244415, + 0.03899035230278969, + 0.12688779830932617, + 0.114081971347332, + -0.013348283246159554, + 0.03325115144252777, + 0.007111718878149986, + 0.048056699335575104, + -0.003726312192156911, + 0.05401211231946945, + 0.05355936661362648, + 0.21303032338619232, + 0.2944865822792053, + -0.13604623079299927, + -0.3770989775657654, + -0.0808275118470192, + -0.006103217601776123, + -0.02005188539624214, + 0.37605899572372437, + 0.7776278853416443, + 0.32064270973205566, + -0.23708422482013702, + -0.23380732536315918, + -0.22103570401668549, + -0.45596328377723694, + 0.07213663309812546, + 0.9384943246841431, + 0.8762810230255127, + 0.3557227551937103, + -0.09239326417446136, + -0.25462013483047485, + -0.9858288168907166, + -0.9860153198242188, + 0.2600172162055969, + 0.7731484770774841, + 0.7665594816207886, + 0.14806008338928223, + 0.13109923899173737, + -0.6917864680290222, + -1.580305814743042, + -0.9557210803031921, + -0.16357193887233734, + 0.3189502954483032, + 0.28703632950782776, + 0.5599567890167236, + 0.2459551841020584, + -0.5451022982597351, + -0.6926754713058472, + -0.4368602931499481, + 0.027606861665844917, + 0.025857241824269295, + 0.5376880764961243, + 0.535673975944519, + 0.09012678265571594, + -0.14688564836978912, + -0.1812361180782318, + 0.050619762390851974, + 0.021388273686170578, + -0.05923623591661453, + -0.006538081914186478, + 0.05171535536646843, + -0.051560595631599426, + -0.007643367163836956, + 0.027748188003897667, + 0.0024676925968378782, + -0.008760283701121807, + 0.13039670884609222, + 0.18568934500217438, + 0.06342563778162003, + 0.030788781121373177, + -0.004423442296683788, + -0.041261281818151474, + 0.013299684040248394, + 0.22491391003131866, + 0.27831292152404785, + 0.0883866474032402, + 0.048967570066452026, + 0.0012756097130477428, + -0.03215779736638069, + 0.02710782177746296, + 0.20178261399269104, + 0.22446107864379883, + 0.06052157282829285, + 0.019020315259695053, + 0.02715166099369526, + -0.03146626800298691, + -0.017960363999009132, + 0.11820292472839355, + 0.16114193201065063, + 0.05221821367740631, + -0.02201441302895546, + -0.026308327913284302, + 0.008580431342124939, + -0.02444308064877987, + 0.061380185186862946, + 0.11184953153133392, + 0.006053542252629995, + -0.03248603641986847, + -0.037558719515800476, + 0.01881473697721958, + -0.02349863201379776, + 0.02150980569422245, + 0.09881952404975891, + 0.03962325677275658, + -0.0031782283913344145, + -0.0030868228059262037, + -0.007606725674122572, + -0.06136326491832733, + 0.022755015641450882, + 0.09683670848608017, + 0.0016674631042405963, + 0.01306125894188881, + 0.011335537768900394, + -0.01769089885056019, + 0.005807302892208099, + 0.19103741645812988, + 0.2631426155567169, + 0.10424992442131042, + 0.025223100557923317, + -0.024689532816410065, + -0.03370697423815727, + 0.0512213259935379, + 0.2983294129371643, + 0.37597405910491943, + 0.18788966536521912, + 0.056492965668439865, + -0.006051253993064165, + -0.027141474187374115, + 0.06733105331659317, + 0.29171472787857056, + 0.32160115242004395, + 0.14176633954048157, + 0.008538221009075642, + -0.013039524666965008, + -0.04279422387480736, + 0.03345612436532974, + 0.19111940264701843, + 0.25728005170822144, + 0.09830093383789062, + -0.03371569141745567, + -0.05277566984295845, + -0.0011038694065064192, + -0.013657800853252411, + 0.10037966072559357, + 0.1724642813205719, + 0.04436478391289711, + -0.02240786701440811, + -0.02181128039956093, + 0.019526727497577667, + -0.050060197710990906, + 0.017275504767894745, + 0.07785085588693619, + -0.001727179973386228, + -0.0014453287003561854, + 0.019352080300450325, + -0.003202121239155531, + -0.04241566359996796, + 0.005586653482168913, + 0.06037082523107529, + 0.014115821570158005, + -0.00568200321868062, + 0.018071964383125305, + -0.0007147599244490266, + 0.011219227686524391, + 0.10582104325294495, + 0.15557849407196045, + 0.06189450994133949, + 0.014160261489450932, + 0.00814653467386961, + -0.028064200654625893, + 0.026086319237947464, + 0.1474728286266327, + 0.18273885548114777, + 0.06638553738594055, + 0.019263381138443947, + 0.028977060690522194, + -0.02551555074751377, + 0.01937149092555046, + 0.12000202387571335, + 0.1285850703716278, + 0.047506313771009445, + -0.011383740231394768, + 0.02826755866408348, + -0.009583448991179466, + -0.02093282900750637, + 0.07994058728218079, + 0.0926218256354332, + 0.0318426676094532, + -0.024409465491771698, + 0.020994359627366066, + 0.03295197710394859, + -0.034276511520147324, + 0.037398867309093475, + 0.0794353187084198, + 0.022805212065577507, + 0.0015407208120450377, + 0.013169347308576107, + 0.038584139198064804, + -0.002118688775226474, + 0.03358406573534012, + 0.09085306525230408, + 0.04255761206150055, + 0.010275964625179768, + 0.025351760908961296, + 0.04205995053052902, + 0.1319226324558258, + 0.049708493053913116, + -0.03743802383542061, + -0.04293569549918175, + -0.07646205276250839, + -0.04986324533820152, + 0.15992362797260284, + 0.011027384549379349, + -0.32150742411613464, + -0.3761928677558899, + -0.1654653549194336, + -0.08728181570768356, + 0.044714685529470444, + -0.007500737439841032, + -0.41376256942749023, + -0.6625701189041138, + -0.21809393167495728, + 0.10641554743051529, + 0.09274336695671082, + 0.10189083218574524, + -0.1175118163228035, + -0.2905261516571045, + -0.06248515099287033, + 0.4791955053806305, + 0.49865299463272095, + 0.23415400087833405, + 0.12729482352733612, + -0.05814196541905403, + -0.003843356389552355, + 0.16410382091999054, + 0.40895968675613403, + 0.22034852206707, + 0.021014101803302765, + -0.05658271536231041, + -0.012199933640658855, + 0.034277670085430145, + 0.09565535932779312, + 0.18921032547950745, + 0.010441004298627377, + -0.07427560538053513, + -0.09049694985151291, + -0.00554919708520174, + 0.021386168897151947, + 0.0297325998544693, + 0.06431404501199722, + -0.07367311418056488, + -0.08734254539012909, + -0.059512097388505936, + 0.11382041126489639, + 0.19622667133808136, + 0.02534862980246544, + -0.09704668819904327, + -0.10857658833265305, + -0.10241919010877609, + -0.037928055971860886, + 0.17917697131633759, + -0.0396210141479969, + -0.472421795129776, + -0.5453466176986694, + -0.23921693861484528, + -0.06353127211332321, + 0.033679377287626266, + -0.011634309776127338, + -0.523267924785614, + -0.8400278091430664, + -0.3026646375656128, + 0.17986975610256195, + 0.20296970009803772, + 0.14190459251403809, + -0.12953802943229675, + -0.3968985378742218, + -0.13779792189598083, + 0.548722505569458, + 0.7039015293121338, + 0.4025704264640808, + 0.19535738229751587, + -0.08568660169839859, + -0.0589536651968956, + 0.1868993639945984, + 0.5782724618911743, + 0.43018248677253723, + 0.08876730501651764, + -0.10219226032495499, + -0.04660544916987419, + 0.018129168078303337, + 0.14359626173973083, + 0.3174169361591339, + 0.07668197154998779, + -0.13716676831245422, + -0.2058524489402771, + -0.023707473650574684, + 0.03213014453649521, + 0.06718969345092773, + 0.0917893499135971, + -0.10766899585723877, + -0.206499844789505, + -0.12713390588760376, + -0.03174767270684242, + 0.046395305544137955, + 0.018318502232432365, + -0.002416136907413602, + -0.027143845334649086, + -0.0036621293984353542, + -0.019220896065235138, + 0.05427055433392525, + 0.05058867856860161, + -0.05274957790970802, + -0.11321325600147247, + -0.07062514126300812, + -0.01720590703189373, + -0.00901520811021328, + 0.01746262051165104, + -0.08946436643600464, + -0.2304752618074417, + -0.1021895483136177, + 0.013501768000423908, + 0.029721295461058617, + -0.010094762779772282, + 0.009764805436134338, + -0.06424269080162048, + -0.03032868541777134, + 0.13044297695159912, + 0.12166891992092133, + 0.07157951593399048, + 0.029467372223734856, + -0.03827595338225365, + -0.031337328255176544, + -0.026486340910196304, + 0.05953369289636612, + 0.029497025534510612, + 0.022669093683362007, + -0.01055963709950447, + -0.025020133703947067, + 0.002589448355138302, + 0.017152298241853714, + 0.062067389488220215, + 0.008266719058156013, + 0.00563611788675189, + -0.0044869836419820786, + 0.003065212396904826, + 0.014371387660503387, + 0.013636622577905655, + 0.021183570846915245, + -0.012462744489312172, + -0.02493542619049549, + 0.009652925655245781, + -0.09309647232294083, + -0.09614148736000061, + 0.020278261974453926, + 0.262399286031723, + 0.0025974283926188946, + -0.09532646089792252, + -0.0391894206404686, + -0.003332971129566431, + -0.25919869542121887, + -0.2104814499616623, + 0.5975717306137085, + 0.20378711819648743, + -0.20521192252635956, + 0.005045099183917046, + 0.16707547008991241, + -0.08322134613990784, + -1.1734565496444702, + 0.4060916006565094, + 0.9109339714050293, + -0.22450445592403412, + -0.14085394144058228, + 0.19534644484519958, + 0.6220589280128479, + -1.0614460706710815, + -1.2444484233856201, + 1.1965712308883667, + 0.5032565593719482, + -0.26604175567626953, + -0.13583213090896606, + 0.6453277468681335, + 0.4994892477989197, + -1.7917202711105347, + -0.15182015299797058, + 0.7891079783439636, + 0.10711944103240967, + -0.11587982624769211, + 0.08287231624126434, + 0.7848142981529236, + -0.1764022707939148, + -1.0492321252822876, + 0.15281184017658234, + 0.3100045919418335, + -0.0461110882461071, + -0.06824400275945663, + 0.25544390082359314, + 0.3444065451622009, + -0.3189513683319092, + -0.3503313362598419, + 0.05462741479277611, + -0.041028521955013275, + 0.00624969182536006, + -0.0014677124563604593, + 0.10383514314889908, + -0.03467189520597458, + -0.03946290910243988, + 0.012734192423522472, + -0.003676857566460967, + -0.1616411954164505, + -0.034441810101270676, + 0.34758275747299194, + -0.0017601394793018699, + -0.17407774925231934, + 0.05167992413043976, + 0.12394318729639053, + -0.018228475004434586, + -0.71342533826828, + 0.39672648906707764, + 0.4870489537715912, + -0.27272745966911316, + -0.02687050960958004, + 0.09090551733970642, + 0.46698617935180664, + -0.6089348196983337, + -0.7488552331924438, + 0.8327828645706177, + 0.19947239756584167, + -0.17806877195835114, + -0.09197663515806198, + 0.3198661506175995, + 0.42619431018829346, + -1.1321229934692383, + -0.05452701821923256, + 0.4155597984790802, + -0.001295815804041922, + -0.06596186012029648, + -0.05821318179368973, + 0.4515152871608734, + 0.06321248412132263, + -0.6065720319747925, + 0.10882120579481125, + 0.13767170906066895, + 0.01809641905128956, + -0.070295050740242, + 0.04035783186554909, + 0.22459834814071655, + -0.048405971378088, + -0.14622822403907776, + -0.01119917817413807, + 0.00666345190256834, + 0.04815478250384331, + -0.017866114154458046, + -0.04813665896654129, + -0.02366034686565399, + 0.03589487820863724, + -0.0066519430838525295, + 0.0004148671869188547, + -0.014153627678751945, + 0.04403751716017723, + 0.04098428785800934, + -0.10525348782539368, + -0.0078808031976223, + 0.0444580540060997, + -0.027595041319727898, + 0.010916849598288536, + -0.1390431821346283, + 0.20334453880786896, + -0.006475532427430153, + -0.16053295135498047, + 0.06964287906885147, + -0.025649840012192726, + 0.12622858583927155, + -0.09694403409957886, + -0.09791161119937897, + 0.2617567479610443, + -0.06268735229969025, + -0.03128494322299957, + -0.017743078991770744, + -0.02372320368885994, + 0.2195650041103363, + -0.2456466406583786, + 0.031090563163161278, + 0.010196326300501823, + -0.04323133826255798, + 0.02746250294148922, + -0.079569011926651, + 0.06894756853580475, + 0.11414647102355957, + -0.12175147980451584, + 0.025397513061761856, + 0.006027852185070515, + 0.013360690325498581, + -0.024561991915106773, + -0.10966529697179794, + 0.04913714900612831, + 0.09801583737134933, + 0.00013951699656900018, + -0.03194398432970047, + 0.002382949460297823, + -0.003335593966767192, + 0.023621119558811188, + 0.024585755541920662, + -0.016027197241783142, + -0.02846739999949932, + -0.012949706055223942, + -0.020852699875831604, + -0.016913240775465965, + 0.016088848933577538, + 0.141468346118927, + 0.07285624742507935, + -0.008997173048555851, + -0.033306676894426346, + -0.03418722003698349, + -0.15127411484718323, + -0.047440435737371445, + 0.2687169015407562, + 0.17237843573093414, + 0.03505166247487068, + -0.06994523108005524, + -0.031143782660365105, + -0.3024960458278656, + -0.1552041918039322, + 0.33517369627952576, + 0.28441429138183594, + 0.06471730768680573, + -0.0613982267677784, + -0.02271229960024357, + -0.29379361867904663, + -0.3259792923927307, + 0.16062304377555847, + 0.29220375418663025, + 0.10862076282501221, + -0.005909152328968048, + 0.049116987735033035, + -0.20140305161476135, + -0.3278747797012329, + -0.02566053718328476, + 0.14338354766368866, + 0.006411381531506777, + -0.007274044211953878, + 0.08232597261667252, + -0.04198717698454857, + -0.17330540716648102, + -0.01131037063896656, + 0.08018575608730316, + -0.02374250255525112, + -0.002276432001963258, + 0.00019528658594936132, + -0.024716932326555252, + 0.026509074494242668, + 0.08361849933862686, + 0.012956380844116211, + -0.06030649319291115, + -0.020338360220193863, + -0.03577016666531563, + -0.06858085840940475, + 0.008245388977229595, + 0.25225168466567993, + 0.16135559976100922, + -0.03690743073821068, + -0.09188401699066162, + -0.10410526394844055, + -0.25971388816833496, + -0.07926154136657715, + 0.3933144509792328, + 0.33186599612236023, + 0.059405017644166946, + -0.11824909597635269, + -0.10528354346752167, + -0.4808295667171478, + -0.25224801898002625, + 0.4267246127128601, + 0.4853539764881134, + 0.16933484375476837, + -0.073345847427845, + -0.02648857608437538, + -0.4723232388496399, + -0.4904792010784149, + 0.1938265562057495, + 0.44070878624916077, + 0.22439399361610413, + 0.03877745941281319, + 0.08536087721586227, + -0.31432414054870605, + -0.5158097743988037, + -0.09537900239229202, + 0.20227058231830597, + 0.07895126938819885, + 0.059195615351200104, + 0.14728911221027374, + -0.059377528727054596, + -0.2884902060031891, + -0.12288203090429306, + 0.05220698565244675, + -0.045279599726200104, + 0.019795719534158707, + -0.009819806553423405, + -0.013713877648115158, + 0.0012175077572464943, + 0.03281072899699211, + 0.0017424041870981455, + -0.028847966343164444, + -0.0032059827353805304, + -0.020358575507998466, + 0.0009416870889253914, + -0.007760196924209595, + 0.07921157032251358, + 0.03826644644141197, + -0.02976907789707184, + -0.03300238028168678, + -0.017963968217372894, + -0.055836472660303116, + -0.03299689665436745, + 0.15166012942790985, + 0.06786434352397919, + 0.008589516393840313, + -0.05790036544203758, + -0.0029997669626027346, + -0.14070068299770355, + -0.08799122273921967, + 0.19680362939834595, + 0.14703704416751862, + 0.03569985553622246, + -0.02847554162144661, + 0.03601403906941414, + -0.1339161992073059, + -0.20527805387973785, + 0.1060374304652214, + 0.16269326210021973, + 0.0575268417596817, + 0.0029672966338694096, + 0.018848277628421783, + -0.1029881089925766, + -0.19446833431720734, + -0.055140964686870575, + 0.09632515162229538, + 0.01196608692407608, + 0.01994382217526436, + 0.0030014747753739357, + 0.0029817752074450254, + -0.09395840018987656, + -0.038611751049757004, + 0.03793984279036522, + -0.006295992527157068, + 0.01736803539097309, + -0.0961727425456047, + 0.1318971812725067, + 0.00169672432821244, + 0.02773740515112877, + -0.03737606480717659, + -0.02413480542600155, + -0.07371329516172409, + 0.04465596005320549, + 0.34972262382507324, + 0.269726425409317, + 0.14907677471637726, + -0.15323053300380707, + -0.24987848103046417, + -0.32931339740753174, + 0.05209995433688164, + 0.5192161798477173, + 0.5108750462532043, + 0.2627664804458618, + -0.26889729499816895, + -0.49891141057014465, + -0.5081418752670288, + 0.13535383343696594, + 0.7318623661994934, + 0.7116816639900208, + 0.2973657250404358, + -0.38982102274894714, + -0.7131763100624084, + -0.5916072130203247, + 0.1200462281703949, + 0.7752112746238708, + 0.6947993636131287, + 0.21100594103336334, + -0.5576100945472717, + -0.7797606587409973, + -0.6058254837989807, + 0.08617032319307327, + 0.6432424187660217, + 0.522933304309845, + 0.16018754243850708, + -0.5134027004241943, + -0.6838728189468384, + -0.5088241100311279, + 0.10101393610239029, + 0.4321025311946869, + 0.3330003023147583, + 0.10116448998451233, + -0.2786642014980316, + -0.4134466052055359, + -0.3247438967227936, + 0.009768294170498848, + 0.008712833747267723, + -0.029476309195160866, + 0.007709377445280552, + 0.025279967114329338, + 0.01615188643336296, + 0.01585867628455162, + -0.0031516810413450003, + -0.06462288647890091, + -0.055517926812171936, + -0.013180199079215527, + -0.014849795028567314, + 0.05535515025258064, + 0.04162544384598732, + 0.0022392054088413715, + -0.09408581256866455, + -0.07889631390571594, + -0.032870080322027206, + 0.0382377915084362, + 0.07495865970849991, + 0.08439645916223526, + 0.008036677725613117, + -0.1167779192328453, + -0.10782196372747421, + -0.06854722648859024, + 0.06310252100229263, + 0.09643208235502243, + 0.08629462122917175, + -0.016969647258520126, + -0.10456187278032303, + -0.10410942137241364, + -0.017384463921189308, + 0.03931420296430588, + 0.11296819150447845, + 0.08688211441040039, + -0.018024103716015816, + -0.0985492691397667, + -0.10534191876649857, + 0.016594627872109413, + 0.024613894522190094, + 0.09626104682683945, + 0.056779902428388596, + -0.01314453687518835, + -0.1173979789018631, + -0.07576211541891098, + -0.00741730397567153, + 0.04463285952806473, + 0.06365535408258438, + 0.029472019523382187, + 0.06097950413823128, + -0.0884813666343689, + -0.020469073206186295, + -0.004499382339417934, + 0.006147715728729963, + 0.0061135985888540745, + 0.046618249267339706, + -0.024977274239063263, + -0.2809607684612274, + -0.20776452124118805, + -0.10792756825685501, + 0.10520339012145996, + 0.2195160835981369, + 0.27846819162368774, + -0.0425783209502697, + -0.4539273977279663, + -0.4210258722305298, + -0.24160517752170563, + 0.2377386838197708, + 0.4254952371120453, + 0.40258923172950745, + -0.08894401043653488, + -0.6261403560638428, + -0.6177268624305725, + -0.2941279113292694, + 0.36115866899490356, + 0.6176164746284485, + 0.5170959234237671, + -0.12760992348194122, + -0.6392932534217834, + -0.6288641095161438, + -0.20397846400737762, + 0.4859760105609894, + 0.7283636927604675, + 0.5233575105667114, + -0.08038943260908127, + -0.513219952583313, + -0.4611802101135254, + -0.08622774481773376, + 0.41959214210510254, + 0.6145293116569519, + 0.4252074360847473, + -0.08993257582187653, + -0.3586794435977936, + -0.23889268934726715, + -0.07402873039245605, + 0.2362663745880127, + 0.33187127113342285, + 0.24442552030086517, + -0.10037989169359207, + -0.1200498715043068, + -0.06188809871673584, + 0.009648810140788555, + 0.07703708112239838, + -0.07734857499599457, + -0.16337357461452484, + -0.13160429894924164, + -0.037760209292173386, + 0.10750655829906464, + 0.21975228190422058, + 0.21332265436649323, + 0.1482381671667099, + -0.012174196541309357, + -0.03128019720315933, + 0.06983920931816101, + 0.2055918425321579, + 0.16611628234386444, + 0.20955723524093628, + 0.21407610177993774, + 0.13214662671089172, + 0.01558306161314249, + 0.20919384062290192, + 0.21453723311424255, + 0.10980720072984695, + 0.10323476791381836, + 0.1754676252603531, + 0.16320686042308807, + 0.076839879155159, + 0.2669583261013031, + 0.29500535130500793, + 0.18005967140197754, + 0.14900699257850647, + 0.2337430715560913, + 0.2607984244823456, + -0.08909865468740463, + 0.12383633106946945, + 0.27329200506210327, + 0.2634970247745514, + 0.2298160344362259, + 0.22673286497592926, + 0.1753624528646469, + -0.14258335530757904, + -0.033422429114580154, + 0.09338828176259995, + 0.21975602209568024, + 0.2488732784986496, + 0.21165378391742706, + 0.08514796197414398, + 0.0776415765285492, + -0.028732767328619957, + -0.0827818363904953, + -0.14784079790115356, + -0.06101813539862633, + -0.10570015013217926, + -0.07298385351896286, + -0.03352680057287216, + -0.08094660192728043, + -0.08546923100948334, + -0.025722583755850792, + -0.04828448221087456, + -0.15816760063171387, + -0.22295169532299042, + -0.04976325109601021, + -0.12255501747131348, + -0.04869991913437843, + 0.09818085283041, + 0.2285904735326767, + 0.015187943354249, + -0.19952231645584106, + -0.1415022611618042, + -0.09511925280094147, + 0.10828559100627899, + 0.35640013217926025, + 0.5399265289306641, + 0.3026861250400543, + -0.10532847791910172, + -0.0455780103802681, + -0.09365752339363098, + 0.2482689470052719, + 0.5483031272888184, + 0.6572608947753906, + 0.4098849594593048, + -0.0039499495178461075, + -0.11641024053096771, + -0.22666053473949432, + -0.03133581206202507, + 0.2815704643726349, + 0.3229265809059143, + 0.009749597869813442, + -0.19616934657096863, + -0.05046992748975754, + -0.15597671270370483, + -0.22775587439537048, + -0.14872166514396667, + -0.12174414098262787, + -0.23433859646320343, + -0.238412007689476, + 0.09725375473499298, + 0.08522887527942657, + 0.006490080617368221, + -0.024619178846478462, + 0.07278231531381607, + 0.13406167924404144, + 0.22993306815624237, + 0.10250072181224823, + 0.09119024127721786, + -0.07687287777662277, + -0.1012108325958252, + -0.09500063210725784, + -0.10082961618900299, + 0.09466016292572021, + 0.11299365013837814, + -0.033278487622737885, + -0.20269805192947388, + -0.21449527144432068, + -0.08820098638534546, + -0.18970704078674316, + -0.050536416471004486, + -0.03471578657627106, + -0.13205547630786896, + -0.18150201439857483, + -0.03963223099708557, + 0.13029472529888153, + -0.11594776809215546, + -0.173879474401474, + 0.017406627535820007, + -0.11885572224855423, + -0.06966021656990051, + 0.1687183529138565, + 0.2677668035030365, + -0.020446041598916054, + -0.11710261553525925, + 0.044354867190122604, + -0.10054060816764832, + -0.1287878155708313, + -0.03600803390145302, + -0.03198331966996193, + -0.22372953593730927, + -0.11045534163713455, + 0.22963544726371765, + 0.16736479103565216, + -0.023956498131155968, + -0.0882943719625473, + -0.11904646456241608, + -0.10481738299131393, + 0.083598293364048, + 0.058089643716812134, + -0.04821285232901573, + 0.16764044761657715, + -0.13788309693336487, + -0.1412951946258545, + 0.059633608907461166, + 0.012824267148971558, + -0.03141501545906067, + -0.017422236502170563, + 0.3908282518386841, + -0.31520241498947144, + -0.27876099944114685, + 0.17109407484531403, + 0.011913848109543324, + -0.04440265893936157, + 0.05610174685716629, + 0.5290316343307495, + -0.4506116211414337, + -0.2946499288082123, + 0.2802693545818329, + 0.04180249199271202, + -0.05673402547836304, + 0.0445592887699604, + 0.4933576285839081, + -0.4903600513935089, + -0.3259376883506775, + 0.26069584488868713, + 0.047843094915151596, + -0.053804315626621246, + 0.029928382486104965, + 0.3588394224643707, + -0.39090782403945923, + -0.18598265945911407, + 0.1703576147556305, + 0.010407418012619019, + 0.019840527325868607, + -0.017079327255487442, + 0.21012797951698303, + -0.1586841642856598, + -0.12738685309886932, + 0.12431345880031586, + 0.028149213641881943, + 0.05083676427602768, + -0.07053223252296448, + 0.12090320140123367, + -0.13737183809280396, + -0.09807822853326797, + 0.07203921675682068, + -0.01965559460222721, + 0.036479320377111435, + -0.02657422423362732, + 0.2924504280090332, + -0.19397024810314178, + -0.20908842980861664, + 0.07435549795627594, + 0.011985386721789837, + -0.051603686064481735, + 0.039122600108385086, + 0.5911946892738342, + -0.45937344431877136, + -0.43863579630851746, + 0.23180224001407623, + 0.05592876672744751, + -0.10227655619382858, + 0.1371937245130539, + 0.7193072438240051, + -0.6789532899856567, + -0.5275344252586365, + 0.4098500609397888, + 0.09136661887168884, + -0.08802130073308945, + 0.12226735055446625, + 0.6819202303886414, + -0.7316576838493347, + -0.5229181051254272, + 0.37578293681144714, + 0.09086397290229797, + -0.05128701403737068, + 0.09287497401237488, + 0.5103837251663208, + -0.6150248646736145, + -0.3208717107772827, + 0.29780012369155884, + 0.071808360517025, + 0.04605705663561821, + 0.028153980150818825, + 0.30872926115989685, + -0.32211968302726746, + -0.1925150454044342, + 0.18948692083358765, + 0.07391810417175293, + 0.08546463400125504, + -0.07042243331670761, + 0.14390304684638977, + -0.22509464621543884, + -0.12615789473056793, + 0.09681600332260132, + 0.0030679223127663136, + 0.06206878274679184, + -0.0493885837495327, + 0.11675205081701279, + -0.09476804733276367, + -0.0708041712641716, + 0.027848264202475548, + 0.018535451963543892, + 0.01112216804176569, + -0.023546719923615456, + 0.2808285057544708, + -0.2312571257352829, + -0.16320407390594482, + 0.15229304134845734, + -0.007220278959721327, + -0.026767488569021225, + -0.008487970568239689, + 0.39064091444015503, + -0.3746477961540222, + -0.22930599749088287, + 0.23297259211540222, + -0.020648201927542686, + -0.03918099403381348, + -0.03193120285868645, + 0.37857353687286377, + -0.38306936621665955, + -0.25103962421417236, + 0.2414209097623825, + 0.007709929719567299, + -0.041483473032712936, + -0.001570625347085297, + 0.315625935792923, + -0.276553213596344, + -0.13154125213623047, + 0.17517149448394775, + 0.03219839558005333, + 0.002647437620908022, + -0.012777225114405155, + 0.17064248025417328, + -0.13943275809288025, + -0.10204917937517166, + 0.09418098628520966, + 0.026260169222950935, + 0.05167905613780022, + -0.024634944275021553, + 0.0931941494345665, + -0.11875593662261963, + -0.0752263143658638, + 0.0569780170917511, + 0.00024334408226422966, + -0.001991289434954524, + -0.012094452045857906, + -0.0012201170902699232, + 0.01342268567532301, + 0.006425719242542982, + 0.01147665549069643, + -0.002208880614489317, + -0.019385183230042458, + -0.024868011474609375, + 0.00465290667489171, + 0.009205960668623447, + 0.0016242304118350148, + 0.0059639886021614075, + -0.03436571732163429, + 0.01672518253326416, + 0.008815832436084747, + 0.06389293074607849, + 0.06249547377228737, + 0.06542838364839554, + 0.043118152767419815, + 0.04117512330412865, + 0.014435848221182823, + 0.0065850247628986835, + 0.03811212629079819, + -0.006077034864574671, + -0.004025861620903015, + 0.006247953977435827, + 0.014478449709713459, + 0.0009701942908577621, + -0.002422194229438901, + 0.009390920400619507, + -0.052253514528274536, + -0.05192738026380539, + -0.010346310213208199, + -0.001328076352365315, + -0.002972622634842992, + 0.0015572139527648687, + 0.022503724321722984, + -0.002475353656336665, + 0.001927886507473886, + 0.02994818612933159, + 0.02062363363802433, + -0.0010653833160176873, + -0.005995174869894981, + 0.024450020864605904, + 0.013005194254219532, + 0.0496530681848526, + 0.029475165531039238, + 0.004157512914389372, + -0.0007043799851089716, + 0.01860312558710575, + 0.03839566186070442, + 0.00014980587002355605, + 0.018569663166999817, + 0.05668198689818382, + 0.04645680636167526, + 0.01642409712076187, + 0.03577466681599617, + 0.03575601801276207, + -0.03680748492479324, + -0.01865880750119686, + 0.041660092771053314, + 0.033268485218286514, + 0.03338993713259697, + 0.04665865749120712, + -0.03322917968034744, + -0.2860279381275177, + -0.28877392411231995, + -0.09617949277162552, + 0.014234350994229317, + 0.038012001663446426, + -0.016850680112838745, + -0.27252569794654846, + -0.6714493632316589, + -0.686245322227478, + -0.3376169502735138, + -0.0812990590929985, + 0.003058002796024084, + -0.026376569643616676, + -0.29216718673706055, + -0.6779875159263611, + -0.6917123198509216, + -0.3184400796890259, + -0.058261968195438385, + 0.06338769942522049, + 0.03199980780482292, + -0.09837217628955841, + -0.3355932831764221, + -0.30900436639785767, + -0.04878076910972595, + 0.061543505638837814, + 0.04651529714465141, + 0.0263908002525568, + 0.0030237447936087847, + -0.10458099842071533, + -0.07959774881601334, + 0.05430716276168823, + 0.056767694652080536, + 0.00796051137149334, + -0.016737859696149826, + -0.042338743805885315, + -0.0198048185557127, + -0.03085070475935936, + -0.058721307665109634, + -0.036032311618328094, + -0.0035414688754826784, + -8.359456842299551e-05, + -0.02213932015001774, + 0.02032857947051525, + 0.021788733080029488, + -0.03522418439388275, + -0.025317413732409477, + -0.042937491089105606, + -0.05680134892463684, + -0.012510996311903, + 0.226289302110672, + 0.24401520192623138, + 0.022300971671938896, + -0.030825607478618622, + -0.05485948920249939, + 0.007590078748762608, + 0.2208130657672882, + 0.6964298486709595, + 0.7457719445228577, + 0.3470557630062103, + 0.06941442936658859, + -0.03543366119265556, + 0.035853609442710876, + 0.2872598171234131, + 0.7504303455352783, + 0.7509996294975281, + 0.34327855706214905, + 0.024429334327578545, + -0.05711393058300018, + -0.034500252455472946, + 0.057939525693655014, + 0.33292675018310547, + 0.3141649067401886, + 0.033748809248209, + -0.062175147235393524, + -0.041224412620067596, + -0.01891348883509636, + -0.014519350603222847, + 0.08635713160037994, + 0.03148616850376129, + -0.08749162405729294, + -0.05658482387661934, + 0.00018510188965592533, + 0.002624311950057745, + -0.003570129396393895, + 0.0067627751268446445, + -0.01349653396755457, + -0.003961967770010233, + 0.0034001911990344524, + -0.00385954394005239, + 0.018012456595897675, + -0.018755480647087097, + -0.03163064643740654, + -0.0035233700182288885, + 0.011690095998346806, + -0.014693490229547024, + 0.017746854573488235, + 0.05693097040057182, + 0.1272590607404709, + 0.23477119207382202, + 0.19823509454727173, + 0.05071045830845833, + -0.007188393268734217, + -0.05571149289608002, + -0.06468938291072845, + -0.017831332981586456, + -0.07572834193706512, + -0.19599483907222748, + -0.15608063340187073, + -0.039450764656066895, + -0.035583946853876114, + -0.1605951488018036, + -0.5041624307632446, + -0.6836286783218384, + -0.3773191571235657, + -0.08623629808425903, + -0.04881078004837036, + 0.029403403401374817, + 0.15516817569732666, + 0.4108496308326721, + 0.6393839716911316, + 0.4688946008682251, + 0.2135964334011078, + 0.0623941570520401, + 0.02426956780254841, + -8.065254223765805e-05, + -0.00816427543759346, + -0.09353788942098618, + -0.06872912496328354, + -0.029405562207102776, + 0.012364620342850685, + 0.0060868943110108376, + 0.017015695571899414, + -0.0076495204120874405, + -0.006090708542615175, + -0.016521835699677467, + 0.009218892082571983, + 0.030833140015602112, + -0.0002345978282392025, + 0.03332215175032616, + 0.0030349211301654577, + 0.009600857272744179, + 0.05706647038459778, + 0.06095677986741066, + -0.016137542203068733, + 0.03195658698678017, + 0.13535599410533905, + 0.28229761123657227, + 0.4573267698287964, + 0.39102476835250854, + 0.17547546327114105, + 0.005337159149348736, + -0.07699840515851974, + -0.12667469680309296, + -0.16613735258579254, + -0.2908898890018463, + -0.44942277669906616, + -0.34229782223701477, + -0.16225378215312958, + -0.1100199744105339, + -0.4044281840324402, + -0.9058251976966858, + -1.1549302339553833, + -0.7502554059028625, + -0.2716369032859802, + -0.13495275378227234, + 0.08614412695169449, + 0.3164423108100891, + 0.7155097723007202, + 1.0356683731079102, + 0.7939887642860413, + 0.39567017555236816, + 0.16957539319992065, + 0.02675812318921089, + 0.048314403742551804, + 0.053107086569070816, + -0.009243623353540897, + -0.011442561633884907, + 0.004911235999315977, + 0.012210517190396786, + 0.006660772021859884, + -0.004562888294458389, + -0.009606098756194115, + -0.01610635593533516, + -0.03475078567862511, + 0.007796770427376032, + 0.02015513926744461, + 0.020311446860432625, + 0.009043446741998196, + -0.01929326355457306, + -0.04183953255414963, + -0.003052672604098916, + 0.020744286477565765, + 0.01371331699192524, + 0.004048139322549105, + 0.0692848190665245, + 0.16867054998874664, + 0.2799474000930786, + 0.28119951486587524, + 0.13579942286014557, + -0.0015732255997136235, + -0.05406518653035164, + -0.05831173434853554, + -0.034435681998729706, + -0.11925295740365982, + -0.2570647895336151, + -0.19120880961418152, + -0.09981344640254974, + -0.011702792719006538, + -0.22477947175502777, + -0.5395713448524475, + -0.7111374139785767, + -0.4207299053668976, + -0.11811137199401855, + -0.035199034959077835, + 0.024358956143260002, + 0.16262274980545044, + 0.46769100427627563, + 0.677872896194458, + 0.4637402892112732, + 0.15558630228042603, + 0.04467496648430824, + 0.03221412003040314, + 0.02430277317762375, + -0.006398700177669525, + -0.07235423475503922, + -0.03669704124331474, + -0.000992153538390994, + 0.02220241352915764, + -0.03329842537641525, + 0.05199713259935379, + -0.14053553342819214, + 0.1906905472278595, + -0.13544943928718567, + 0.08535720407962799, + -0.009813228622078896, + 0.03578176349401474, + -0.05863757058978081, + 0.33848440647125244, + -0.49837300181388855, + 0.15308170020580292, + 0.14865124225616455, + -0.12349266558885574, + -0.025796135887503624, + 0.17790427803993225, + -0.7813658714294434, + 0.853188693523407, + 0.2489670068025589, + -0.7378701567649841, + 0.2207188457250595, + 0.05207442864775658, + -0.4280349314212799, + 1.1408430337905884, + -0.24505679309368134, + -1.5490919351577759, + 1.4560288190841675, + -0.31143030524253845, + -0.03536878153681755, + 0.5640448331832886, + -0.6874421834945679, + -1.210310697555542, + 2.6637399196624756, + -1.6589887142181396, + 0.2221546173095703, + 0.10179737955331802, + -0.4354941248893738, + 0.034149203449487686, + 1.480568528175354, + -2.072199821472168, + 0.9205833673477173, + 0.021510563790798187, + -0.07755836099386215, + 0.17983688414096832, + 0.040537625551223755, + -0.5325585603713989, + 0.550999641418457, + -0.11060550063848495, + -0.09052976220846176, + -0.048361390829086304, + 0.03450514376163483, + -0.11854307353496552, + 0.23462797701358795, + -0.17563995718955994, + 0.0653814822435379, + -0.009748813696205616, + 0.07013920694589615, + -0.08628369867801666, + 0.3019683063030243, + -0.630340576171875, + 0.274477481842041, + 0.15417183935642242, + -0.036220982670784, + -0.07344137132167816, + 0.2339126616716385, + -1.0395091772079468, + 1.2002928256988525, + 0.085142120718956, + -0.7080597281455994, + 0.23101751506328583, + 0.016307154670357704, + -0.45877355337142944, + 1.617128849029541, + -0.6593433618545532, + -1.8957709074020386, + 1.746606469154358, + -0.37062564492225647, + 0.01213759370148182, + 0.5851964354515076, + -1.0307577848434448, + -1.4803766012191772, + 3.812014102935791, + -2.0028398036956787, + 0.12008816003799438, + 0.01813559979200363, + -0.5065457820892334, + 0.17598780989646912, + 2.0418734550476074, + -2.680522918701172, + 0.7466094493865967, + 0.16271913051605225, + -0.04379571974277496, + 0.21930621564388275, + 0.041255541145801544, + -0.6644601821899414, + 0.481300413608551, + 0.05410065874457359, + -0.09025495499372482, + 0.01954805478453636, + 0.01899997517466545, + -0.1337241530418396, + 0.19821906089782715, + -0.06395180523395538, + -0.03586877882480621, + 0.01973363384604454, + 0.013873124495148659, + -0.09288538247346878, + 0.4300728440284729, + -0.4235192537307739, + 0.03646458685398102, + 0.10077393800020218, + -0.07569073140621185, + -0.08176662772893906, + 0.3834531605243683, + -0.747482419013977, + 0.4493187367916107, + 0.2960513234138489, + -0.5245057344436646, + 0.27831950783729553, + 0.0731748417019844, + -0.45574328303337097, + 0.6987965703010559, + 0.019539732486009598, + -1.1160184144973755, + 1.0756875276565552, + -0.3804619312286377, + -0.040626902133226395, + 0.2780243456363678, + -0.32946258783340454, + -0.8122196793556213, + 1.9535348415374756, + -1.300661563873291, + 0.3443142771720886, + 0.04858396574854851, + -0.17409801483154297, + -0.07783844321966171, + 1.0875797271728516, + -1.5148566961288452, + 0.8014272451400757, + -0.19643208384513855, + -0.033590562641620636, + 0.11178025603294373, + 0.08284300565719604, + -0.5165408849716187, + 0.5841389894485474, + -0.24739950895309448, + 0.027926180511713028, + -0.028708497062325478, + 0.0037401756271719933, + -0.0047450135461986065, + 0.008427698165178299, + 0.009801353327929974, + -0.0029346586670726538, + -0.010193527676165104, + 0.014876358211040497, + 0.009861295111477375, + -0.005554665345698595, + -0.06270359456539154, + -0.0316256619989872, + 0.006706684362143278, + 0.04316525161266327, + 0.008637072518467903, + -0.03666357323527336, + -0.0719730481505394, + -0.1525861918926239, + -0.14396126568317413, + -0.05387119948863983, + 0.01955549605190754, + 0.007112634833902121, + -0.05175568535923958, + -0.16772602498531342, + -0.20807777345180511, + -0.18768996000289917, + -0.17093753814697266, + -0.03334345668554306, + 0.0011808606795966625, + -0.01579100452363491, + -0.12589050829410553, + -0.17219413816928864, + -0.19648219645023346, + -0.21980451047420502, + -0.04920821264386177, + 0.0012217299081385136, + 0.023885242640972137, + -0.056074876338243484, + -0.13907776772975922, + -0.19139252603054047, + -0.13652737438678741, + -0.0027339402586221695, + 0.004720518831163645, + -0.00037206560955382884, + 0.017924504354596138, + -0.02118082158267498, + -0.06553903222084045, + -0.0435921773314476, + 0.02721239998936653, + 0.020702000707387924, + 0.024033410474658012, + 0.005382229574024677, + -0.01273527555167675, + -0.01742861233651638, + 0.007402990944683552, + 0.010333286598324776, + 0.02598601020872593, + 0.012456837110221386, + -0.03471057116985321, + -0.10051856189966202, + -0.08084382116794586, + -0.023420603945851326, + 0.031205907464027405, + 0.00424322672188282, + -0.03734385594725609, + -0.1152661070227623, + -0.2012551724910736, + -0.1995576024055481, + -0.07972321659326553, + -0.011126434430480003, + -0.0185835100710392, + -0.06944561004638672, + -0.21481844782829285, + -0.26795628666877747, + -0.24916253983974457, + -0.17833945155143738, + -0.06658200174570084, + -0.00305415247566998, + -0.054028186947107315, + -0.19072681665420532, + -0.256619930267334, + -0.26868295669555664, + -0.21621295809745789, + -0.06564134359359741, + 0.0031192339956760406, + 0.013205861672759056, + -0.08044812828302383, + -0.18137820065021515, + -0.23007699847221375, + -0.13054916262626648, + -0.01135951280593872, + 0.013734308071434498, + 0.010981118306517601, + -0.02249351143836975, + -0.05804377421736717, + -0.10652261227369308, + -0.04163172468543053, + 0.017101088538765907, + -0.028687385842204094, + -0.0019976652693003416, + 0.009987232275307178, + 0.010130539536476135, + 0.0015575449215248227, + -0.000983694102615118, + -0.012845008634030819, + 0.01329281460493803, + 0.0029350779950618744, + -0.003755913581699133, + -0.036475058645009995, + -0.0245466697961092, + -0.0020879909861832857, + 0.025867130607366562, + -0.0065954397432506084, + 0.008656582795083523, + -0.04037104919552803, + -0.11718368530273438, + -0.13506115972995758, + -0.024255141615867615, + 0.014097613282501698, + -0.0009370348998345435, + -0.010953565128147602, + -0.12869219481945038, + -0.18789908289909363, + -0.19098156690597534, + -0.12795749306678772, + -0.002666366985067725, + -0.004907527007162571, + -0.014610078185796738, + -0.11913872510194778, + -0.19921070337295532, + -0.21869640052318573, + -0.1849898099899292, + -0.03470952808856964, + 0.0064156935550272465, + 0.03401843458414078, + -0.04000416398048401, + -0.12354391813278198, + -0.16908879578113556, + -0.10385500639677048, + 0.002833302365615964, + -0.036176733672618866, + -0.001048827893100679, + 0.010002595372498035, + -0.020798830315470695, + -0.0488261841237545, + -0.002972641494125128, + 0.016395021229982376, + -0.045770127326250076, + -0.12710650265216827, + -0.1637774109840393, + -0.1411965787410736, + 0.20447289943695068, + 0.509396493434906, + 0.07264503091573715, + 0.12041529268026352, + -0.015143441036343575, + -0.2673257887363434, + -0.3589763641357422, + 0.11289574205875397, + 0.8517020344734192, + 0.7068799138069153, + 0.067301444709301, + -0.02102830447256565, + -0.5235708355903625, + -1.2064802646636963, + -0.856619656085968, + 0.26774707436561584, + 0.6825867295265198, + 0.13516077399253845, + 0.3054035007953644, + -0.0727991834282875, + -1.4912222623825073, + -1.906838297843933, + -0.8574200868606567, + -0.15282419323921204, + 0.39327505230903625, + 0.9758505821228027, + 1.2323224544525146, + 0.18179064989089966, + -0.947610080242157, + -0.6657719016075134, + -0.19935055077075958, + -0.09150458872318268, + 0.34379544854164124, + 1.2025749683380127, + 0.9517407417297363, + -0.12023784220218658, + -0.3146151900291443, + -0.1049022302031517, + -0.34867578744888306, + -0.32945582270622253, + 0.28920575976371765, + 0.7844374179840088, + 0.35520124435424805, + 0.007452746387571096, + 0.018862545490264893, + -0.0021927610505372286, + 0.0321974977850914, + 0.05439181253314018, + -0.030729038640856743, + -0.03517322614789009, + -0.037830010056495667, + -0.056672073900699615, + -0.017769837751984596, + 0.06385952979326248, + 0.08161566406488419, + 0.07809178531169891, + 0.06333671510219574, + -0.036322008818387985, + -0.06432312726974487, + -0.03629852458834648, + 0.010879911482334137, + 0.088901087641716, + 0.0021402277052402496, + 0.09618857502937317, + 0.02661084569990635, + -0.03414442762732506, + -0.08736730366945267, + -0.048222169280052185, + 0.03507986292243004, + -0.053828027099370956, + 0.006044292356818914, + 0.04232194274663925, + 0.001624415279366076, + -0.028371643275022507, + -0.08724038302898407, + -0.005835397634655237, + 0.01057528518140316, + 0.04210871085524559, + 0.06106603890657425, + 0.04250370338559151, + 0.0028668276499956846, + -0.07583706080913544, + -0.06849333643913269, + -0.08538331836462021, + -0.021475542336702347, + 0.044341571629047394, + 0.03604369983077049, + 0.05146002024412155, + 0.00280605535954237, + -0.004615028854459524, + -0.07857430726289749, + -0.03716180846095085, + 0.010876243002712727, + -0.03418488800525665, + 0.007391764782369137, + 0.05969953536987305, + 0.08769611269235611, + 0.066011443734169, + -0.10404568910598755, + -0.27194535732269287, + -0.05224551260471344, + -0.03618992492556572, + -0.023098375648260117, + 0.13832588493824005, + 0.21510572731494904, + -0.07285867631435394, + -0.489085853099823, + -0.33285844326019287, + -0.04830349236726761, + 0.014211038127541542, + 0.2612524926662445, + 0.6911754608154297, + 0.5294638276100159, + -0.2706173360347748, + -0.39350029826164246, + -0.05156399682164192, + -0.16490484774112701, + 0.1161464974284172, + 0.8029336929321289, + 1.1809980869293213, + 0.5025736689567566, + 0.07084998488426208, + -0.1901131123304367, + -0.4918227195739746, + -0.603122889995575, + -0.09460704773664474, + 0.5786081552505493, + 0.35392242670059204, + 0.1328991800546646, + -0.008106965571641922, + -0.2159435749053955, + -0.6369062662124634, + -0.5241336822509766, + 0.06276796758174896, + 0.1139409989118576, + 0.05483332276344299, + 0.1703934520483017, + 0.14603517949581146, + -0.16187912225723267, + -0.4139055907726288, + -0.14918148517608643, + -0.06163417547941208, + 0.005302567034959793, + 0.015524876303970814, + -0.11895350366830826, + -0.19724233448505402, + 0.03412429615855217, + 0.10862118750810623, + 0.08550503104925156, + -0.008599682711064816, + -0.03031114675104618, + -0.33224624395370483, + -0.27994298934936523, + 0.196475550532341, + 0.31109708547592163, + 0.17151644825935364, + -0.04994147643446922, + -0.167176753282547, + -0.5247878432273865, + -0.21136601269245148, + 0.54701828956604, + 0.6110883951187134, + 0.04194486886262894, + -0.27640673518180847, + -0.0795169249176979, + -0.360530287027359, + 0.3472684621810913, + 1.5428175926208496, + 1.0249378681182861, + -0.2724844515323639, + -0.3013695478439331, + 0.020736562088131905, + -0.019495302811264992, + 0.7758124470710754, + 1.5381159782409668, + 0.028625331819057465, + -1.289720892906189, + -0.5894255638122559, + 0.0526396706700325, + 0.11443997919559479, + 0.5935031771659851, + 0.47169724106788635, + -1.2507063150405884, + -1.351940631866455, + -0.03894977271556854, + 0.05095001682639122, + 0.01581231690943241, + 0.11137383431196213, + -0.22327138483524323, + -0.9629225730895996, + -0.2607772946357727, + 0.5907121300697327, + 0.006906076334416866, + 0.002633580705150962, + 0.01940075121819973, + 0.0143396882340312, + 0.020781584084033966, + -0.07249777764081955, + -0.016355905681848526, + 0.016553230583667755, + -0.027528395876288414, + 0.0244428887963295, + 0.024910561740398407, + 0.027229825034737587, + -0.04104151204228401, + 0.007100561633706093, + 0.0157785601913929, + -0.06626633554697037, + 0.006520191207528114, + 0.021171070635318756, + 0.036674920469522476, + -0.06950324773788452, + -0.03003627620637417, + 2.178798422391992e-05, + -0.07278106361627579, + 0.014382920227944851, + 0.0982266515493393, + 0.1454961597919464, + -0.10096189379692078, + 0.022237209603190422, + -0.00040665315464138985, + -0.013766243122518063, + 0.06440296769142151, + 0.21751047670841217, + 0.02519127167761326, + -0.23383572697639465, + 0.0038903038948774338, + -0.042271602898836136, + -0.012596859596669674, + 0.023778460919857025, + 0.07685687392950058, + -0.21480663120746613, + -0.19205358624458313, + 0.04876565560698509, + -0.016765035688877106, + -0.02620583213865757, + 0.01641852967441082, + 0.02201787941157818, + -0.07457322627305984, + -0.003633625339716673, + 0.07550841569900513, + 0.024774253368377686, + 0.04710151255130768, + 0.09110233932733536, + -0.017366377636790276, + -0.04366954043507576, + -0.039786458015441895, + 0.005311290733516216, + 0.037867460399866104, + 0.05367766693234444, + 0.07434491813182831, + -0.07251215726137161, + -0.04231821000576019, + -0.023427855223417282, + 0.036294277757406235, + 0.07782749086618423, + 0.11835407465696335, + 0.08753973245620728, + -0.20742319524288177, + -0.13341759145259857, + -0.008225077763199806, + 0.07292432337999344, + 0.006392402108758688, + 0.021914338693022728, + -0.09218581020832062, + -0.44192466139793396, + -0.1744878888130188, + 0.014938815496861935, + 0.10678526759147644, + -0.012087192386388779, + -0.024533385410904884, + -0.1804407387971878, + -0.3253834545612335, + 0.040678758174180984, + 0.2011708915233612, + 0.17262929677963257, + -0.0045212251134216785, + -0.033313386142253876, + -0.10575363039970398, + -0.07636692374944687, + 0.20343273878097534, + 0.28330928087234497, + 0.043149981647729874, + -0.01109551265835762, + -0.0027725452091544867, + 0.003926735837012529, + 0.029440222308039665, + 0.23945140838623047, + 0.09122566133737564, + -0.15140119194984436, + 0.08737201988697052, + 0.07120998948812485, + 0.05722665786743164, + -0.04388495534658432, + 0.02116825245320797, + 0.023315919563174248, + 0.10898162424564362, + 0.11808467656373978, + 0.03412344306707382, + 0.002771642990410328, + -0.1959579437971115, + -0.05181330814957619, + -0.0044630044139921665, + 0.12481725960969925, + 0.09140311926603317, + 0.03444851189851761, + -0.10931172221899033, + -0.3204459846019745, + -0.21193139255046844, + -0.11101037263870239, + 0.04186606407165527, + -0.07420916110277176, + -0.2004990428686142, + -0.26937955617904663, + -0.12928874790668488, + 0.20819628238677979, + -0.17379426956176758, + -0.2181481271982193, + 0.005387924611568451, + -0.24132733047008514, + -0.23942433297634125, + 0.41489261388778687, + 1.0702778100967407, + 0.024913936853408813, + -0.28405970335006714, + 0.083008773624897, + -0.11059781163930893, + -0.17623695731163025, + -0.17386195063591003, + 0.010644182562828064, + -0.32716259360313416, + -0.2135595828294754, + 0.1223129853606224, + 0.07060510665178299, + -0.048680394887924194, + -0.3332099914550781, + -0.25886017084121704, + -0.18619979918003082, + -0.00733158877119422, + 0.03393476828932762, + -0.010564662516117096, + -0.01817108877003193, + -0.05650597810745239, + -0.01891104131937027, + -0.0554141066968441, + -0.004592927638441324, + -0.0013615720672532916, + -0.05552899092435837, + -0.0560498908162117, + -0.1080632209777832, + -0.013965745456516743, + -0.03290533646941185, + -0.02599845454096794, + -0.02877708151936531, + -0.05670137703418732, + -0.07158109545707703, + -0.08808472007513046, + -0.03919175639748573, + -0.08478893339633942, + -0.08045543730258942, + -0.10066724568605423, + -0.048338882625103, + -0.06750114262104034, + 0.08164039999246597, + 0.3343777060508728, + 0.004952755756676197, + -0.14891156554222107, + 0.032855477184057236, + -0.03277512267231941, + 0.0474768728017807, + 0.6316664814949036, + 1.2214386463165283, + 0.2548498213291168, + -0.13185030221939087, + -0.018188906833529472, + -0.07653989642858505, + -0.01643386110663414, + 0.06630122661590576, + 0.23864209651947021, + -0.013703612610697746, + -0.09347789734601974, + -0.0900193303823471, + -0.04930814355611801, + -0.02791711315512657, + -0.15441712737083435, + -0.01623091846704483, + -0.0447690524160862, + -0.06071227043867111, + -0.04737209901213646, + -0.059769801795482635, + -0.04375007003545761, + -0.00650476710870862, + 0.021540174260735512, + -0.05590728670358658, + -0.13030850887298584, + -0.022067781537771225, + -0.05066747963428497, + 0.00609770929440856, + 0.108611099421978, + 0.1621929407119751, + 0.05232185125350952, + -0.049729123711586, + -0.11906369775533676, + -0.030973592773079872, + 0.057787079364061356, + 0.1610448956489563, + 0.18756121397018433, + 0.07277501374483109, + -0.05777435004711151, + -0.05227195844054222, + 0.14434091746807098, + 0.1889694482088089, + 0.26951169967651367, + 0.4710105359554291, + 0.2164669781923294, + 0.05052375793457031, + -0.0038236663676798344, + 0.20267778635025024, + 0.31214746832847595, + 0.7506387829780579, + 1.2302387952804565, + 0.4363090693950653, + 0.16759593784809113, + -0.049752235412597656, + 0.044786907732486725, + 0.14537742733955383, + 0.2227499932050705, + 0.37362414598464966, + 0.16590620577335358, + 0.0864599421620369, + -0.14058542251586914, + -0.04404178634285927, + -0.0325944609940052, + -0.019113417714834213, + 0.17414243519306183, + 0.11160623282194138, + -0.034911543130874634, + 0.1523953527212143, + 0.04554234445095062, + -0.054958827793598175, + -0.11794494092464447, + -0.19570015370845795, + -0.21358126401901245, + -0.1885669231414795, + -0.08286706358194351, + -0.29818814992904663, + -0.52330082654953, + -0.6190353631973267, + -0.682529091835022, + -0.6171367764472961, + -0.4793100655078888, + -0.11180876195430756, + -0.3490432798862457, + -0.5531057715415955, + -0.6426181793212891, + -0.6420838832855225, + -0.4970071613788605, + -0.27038174867630005, + -0.09740017354488373, + -0.1929621547460556, + -0.30848363041877747, + -0.27204805612564087, + -0.2515120208263397, + -0.07497832179069519, + 0.03551386669278145, + -0.05060403421521187, + 0.08276989310979843, + 0.14321963489055634, + 0.3583574593067169, + 0.40667927265167236, + 0.39398193359375, + 0.27561235427856445, + 0.005085935816168785, + 0.2793635427951813, + 0.48155927658081055, + 0.7088037729263306, + 0.7394692897796631, + 0.6158861517906189, + 0.3986552655696869, + 0.025508087128400803, + 0.38533228635787964, + 0.5305332541465759, + 0.6659612059593201, + 0.6396889090538025, + 0.5396444797515869, + 0.39010515809059143, + -0.03072960674762726, + 0.014305810444056988, + 0.029885446652770042, + 0.038084372878074646, + 0.012448564171791077, + 0.034353457391262054, + 0.048626724630594254, + 0.048866890370845795, + 0.07561437785625458, + 0.09152165800333023, + 0.08432324975728989, + 0.09332144260406494, + 0.07517607510089874, + 0.049146559089422226, + 0.03146318346261978, + 0.06335246562957764, + 0.06438779830932617, + 0.06851581484079361, + 0.09263566881418228, + 0.06460423022508621, + 0.011992924846708775, + 0.03396693989634514, + 0.04433950409293175, + 0.04642309248447418, + 0.0022602551616728306, + -0.0361824594438076, + -0.0005105047021061182, + 0.030808264389634132, + 0.0022333709057420492, + -0.017826544120907784, + -0.03796307370066643, + -0.012887164019048214, + -0.028499294072389603, + -0.03367336839437485, + -0.03668365254998207, + -0.02807682938873768, + -0.07444571703672409, + -0.081318199634552, + -0.09610070288181305, + -0.05368436127901077, + -0.09006591141223907, + -0.10038736462593079, + -0.04115951433777809, + -0.056811004877090454, + -0.09935522079467773, + -0.11107856035232544, + -0.07852742075920105, + -0.0942930206656456, + -0.07625897973775864, + -0.12966541945934296, + -0.038938648998737335, + 0.04580259323120117, + 0.10179819911718369, + 0.17127273976802826, + 0.17857632040977478, + 0.13426578044891357, + 0.04687841981649399, + 0.2424812912940979, + 0.42633309960365295, + 0.5291624069213867, + 0.6012980937957764, + 0.5449428558349609, + 0.3945220708847046, + 0.07037744671106339, + 0.26918724179267883, + 0.44614800810813904, + 0.5331310629844666, + 0.568580687046051, + 0.43367546796798706, + 0.25516101717948914, + 0.08428427577018738, + 0.177769735455513, + 0.24885930120944977, + 0.2178547978401184, + 0.13834305107593536, + 0.07452446967363358, + 0.005187708884477615, + 0.050621017813682556, + -0.08428733795881271, + -0.15576106309890747, + -0.25531095266342163, + -0.34646397829055786, + -0.3276817202568054, + -0.24377694725990295, + 0.02817704901099205, + -0.2531633675098419, + -0.3907041549682617, + -0.5944734811782837, + -0.6062930822372437, + -0.5171639919281006, + -0.3501560389995575, + -0.019397703930735588, + -0.2758809030056, + -0.4118667244911194, + -0.5375933051109314, + -0.5525977611541748, + -0.44681206345558167, + -0.2748269736766815, + -0.04229651764035225, + -0.005005967803299427, + -0.011332424357533455, + 0.011387092061340809, + -0.015463154762983322, + -0.012038768269121647, + 0.011360889300704002, + 0.03551746904850006, + 0.05123865604400635, + 0.020377267152071, + 0.1065637394785881, + 0.18875306844711304, + 0.18516196310520172, + 0.12519532442092896, + -0.042940977960824966, + -0.03246130794286728, + -0.016645772382616997, + 0.07807288318872452, + -0.7815885543823242, + -0.5930942296981812, + 0.03312799707055092, + -0.04537777230143547, + -0.022234303876757622, + 0.009241255931556225, + 0.16947965323925018, + -0.0700032040476799, + -0.06346366554498672, + 0.09555318206548691, + 0.02858082763850689, + 0.009246457368135452, + 0.03902693837881088, + 0.007071994710713625, + 0.10085106641054153, + 0.0881502702832222, + 0.011019160971045494, + 0.006030070595443249, + -0.012882355600595474, + -0.01701420359313488, + 0.022596944123506546, + -0.05345382168889046, + 0.02355102449655533, + -0.0091088330373168, + 0.00015542628534603864, + -0.0004997836658731103, + -0.006951311603188515, + 0.01267238613218069, + -0.0033983420580625534, + -0.0030770134180784225, + 0.02975126914680004, + 0.010702245868742466, + -0.016947058960795403, + 0.007774800062179565, + 0.09566964209079742, + 0.07426714897155762, + 0.1621979922056198, + 0.12728945910930634, + 0.06112523376941681, + 0.06061968579888344, + 0.07934501022100449, + 0.11534841358661652, + 0.10001469403505325, + 0.15475066006183624, + 0.1828109323978424, + 0.02134544588625431, + -0.015320047736167908, + 0.012000483460724354, + -0.014393450692296028, + -1.5520576238632202, + -1.2115217447280884, + 0.017239907756447792, + -0.007013735361397266, + 0.0019166347337886691, + 0.025112343952059746, + 0.1803419440984726, + -0.30807924270629883, + -0.33957329392433167, + 0.10846519470214844, + 0.06151076406240463, + 0.054799750447273254, + 0.06235412135720253, + 0.09605015069246292, + 0.16495031118392944, + 0.12624189257621765, + 0.12234552949666977, + 0.006969878450036049, + 0.0033541936427354813, + 0.008165130391716957, + 0.035377491265535355, + -0.03170061111450195, + 0.019396571442484856, + -0.011411413550376892, + 0.019043665379285812, + 0.00957057997584343, + 0.0055394587107002735, + 0.05569477006793022, + 0.0076510305516421795, + 0.018707536160945892, + 0.06073765829205513, + 0.006503407843410969, + -0.0058801183477044106, + -0.03229741007089615, + 0.0386439748108387, + 0.03167358413338661, + 0.027749545872211456, + -0.04634377732872963, + -0.00019781991431955248, + 0.024982664734125137, + 0.009453915059566498, + 0.1091528981924057, + 0.21055325865745544, + 0.23810525238513947, + 0.13829846680164337, + -0.019112061709165573, + -0.0014926757430657744, + 0.01856786385178566, + 0.10649964213371277, + -0.8599057793617249, + -0.6383436322212219, + 0.10839059948921204, + -0.038730181753635406, + -0.030203847214579582, + -0.033147793263196945, + 0.18132103979587555, + -0.1427767276763916, + -0.11132896691560745, + 0.10957232862710953, + -0.00349965482018888, + 0.03486581891775131, + 0.016247740015387535, + 0.060106489807367325, + 0.1439678966999054, + 0.07201634347438812, + 0.07603273540735245, + -0.0072280303575098515, + 0.01600506529211998, + -0.012912745587527752, + 0.015192546881735325, + -0.034853674471378326, + 0.026164958253502846, + 0.001483929343521595, + 0.0508253313601017, + -0.010546445846557617, + -0.024398569017648697, + -0.0043407524935901165, + 0.0030393539927899837, + -0.009643012657761574, + -0.008882591500878334, + 0.01182172168046236, + 0.003359999740496278, + -0.01145304087549448, + -7.34154018573463e-05, + 0.007416137028485537, + -0.012022661976516247, + 0.013550116680562496, + -0.005982181057333946, + -0.019205773249268532, + -0.0811527743935585, + -0.06323252618312836, + -0.026379290968179703, + -0.04671972244977951, + -0.006205265875905752, + 0.05242094770073891, + 0.05065605416893959, + 0.01961991749703884, + 0.021542323753237724, + 0.04147094115614891, + 0.04451332613825798, + 0.05155060812830925, + 0.15659169852733612, + 0.4448348879814148, + 0.7207449078559875, + 0.8680058717727661, + 0.7269517779350281, + 0.36259666085243225, + 0.10394725203514099, + -0.20449180901050568, + -0.42664405703544617, + -0.7290332317352295, + -0.9376083016395569, + -0.735107958316803, + -0.3541502356529236, + -0.23789332807064056, + -0.10901623964309692, + -0.26809337735176086, + -0.38465574383735657, + -0.44440212845802307, + -0.4070444703102112, + -0.22405119240283966, + -0.14190013706684113, + 0.07151509076356888, + 0.21848519146442413, + 0.41893038153648376, + 0.4783499836921692, + 0.4281534254550934, + 0.28631147742271423, + 0.057699400931596756, + 0.0029010034631937742, + -0.02580493874847889, + -0.02152368798851967, + -0.025850815698504448, + 0.004789783153682947, + 0.021941278129816055, + 0.00574735039845109, + -0.004016151186078787, + -0.014377521350979805, + -0.0828985944390297, + -0.06380187720060349, + -0.048879947513341904, + -0.04580164700746536, + -0.030843649059534073, + 0.024663949385285378, + 0.03409295156598091, + 0.060452476143836975, + 0.037006158381700516, + 0.058853648602962494, + 0.07275765389204025, + 0.02882941998541355, + 0.14549848437309265, + 0.4268765151500702, + 0.7150183320045471, + 0.8942612409591675, + 0.7532845139503479, + 0.3846176564693451, + 0.15604183077812195, + -0.19108416140079498, + -0.42633384466171265, + -0.7508237361907959, + -0.9448286890983582, + -0.719300389289856, + -0.3583783805370331, + -0.2060524821281433, + -0.10382426530122757, + -0.2624296545982361, + -0.4049411416053772, + -0.4338999092578888, + -0.41390693187713623, + -0.22797809541225433, + -0.14593803882598877, + 0.08197329193353653, + 0.2430788278579712, + 0.3906225562095642, + 0.47147202491760254, + 0.42429792881011963, + 0.29326340556144714, + 0.06683206558227539, + 0.004355552606284618, + -0.007973028346896172, + 0.0035172239877283573, + -0.0018502225866541266, + -0.015291260555386543, + 0.0025160792283713818, + 0.0015979957534000278, + 0.011951611377298832, + -0.0004334237310104072, + -0.00172338483389467, + 0.017284434288740158, + -0.00445173867046833, + -0.004828867502510548, + 0.004030159674584866, + 0.03321678191423416, + -0.016998661682009697, + -0.029765218496322632, + -0.07912255078554153, + -0.0494595468044281, + 0.012136446312069893, + 0.029541414231061935, + -0.01129366084933281, + 0.09502168744802475, + 0.21533286571502686, + 0.3453419804573059, + 0.22987395524978638, + 0.04720258712768555, + 0.0032486498821526766, + -0.0042808204889297485, + -0.10162857174873352, + -0.21601493656635284, + -0.3040534257888794, + -0.19600912928581238, + -0.0568307563662529, + -0.0062937624752521515, + -0.021828925237059593, + -0.03831009939312935, + -0.08992031216621399, + -0.08103442937135696, + -0.07600760459899902, + -0.02319694682955742, + -0.008472982794046402, + -0.004151565954089165, + 0.05002164468169212, + 0.0985124409198761, + 0.11273156106472015, + 0.10279814153909683, + 0.032678257673978806, + -0.023295480757951736, + -0.022312145680189133, + 0.032877422869205475, + 0.08301658928394318, + -0.049675002694129944, + -0.05956050381064415, + 0.006878976244479418, + 0.011597251519560814, + -0.03617611899971962, + -0.005020621232688427, + 0.0066283573396503925, + 0.061849869787693024, + 0.0668889507651329, + -0.1120104044675827, + 0.0215831957757473, + -0.008177083916962147, + 0.019240612164139748, + -0.03794482350349426, + -0.21581093966960907, + 0.3248063623905182, + 0.0525924488902092, + -0.13873063027858734, + -0.030904211103916168, + -0.004122832324355841, + 0.2784009277820587, + -0.42068102955818176, + -0.15351417660713196, + 0.4266241192817688, + -0.10780557245016098, + 0.03840374946594238, + -0.15116721391677856, + 0.2292502224445343, + 0.23400554060935974, + -0.5023872256278992, + 0.14868289232254028, + 0.09809935092926025, + 0.03480924293398857, + -0.046804867684841156, + -0.14212554693222046, + 0.3073779344558716, + -0.029529480263590813, + -0.13998086750507355, + -0.02750661037862301, + 0.010526027530431747, + 0.032874979078769684, + -0.07645174115896225, + -0.02746269293129444, + 0.10902399569749832, + -0.00446560001000762, + -0.01339190173894167, + 0.003540819976478815, + -0.04410126060247421, + -0.10884726047515869, + 0.016081949695944786, + 0.15211890637874603, + 0.04027504846453667, + -0.05552368983626366, + 0.04718002676963806, + 0.014503135345876217, + -0.2764658033847809, + -0.16068166494369507, + 0.3356778621673584, + 0.06485499441623688, + -0.07164154946804047, + 0.084479421377182, + 0.2702949047088623, + -0.1339409202337265, + -0.9642015695571899, + 0.47433769702911377, + 0.4715694189071655, + -0.17669782042503357, + -0.04434441775083542, + 0.2641690671443939, + 0.7357130646705627, + -1.2222046852111816, + -0.8205837607383728, + 0.9091072678565979, + 0.14896778762340546, + -0.09332367032766342, + -0.16173647344112396, + 0.8782246708869934, + 0.3819980323314667, + -1.619883418083191, + 0.059255462139844894, + 0.42745286226272583, + -0.03186821565032005, + -0.16420172154903412, + 0.12124066799879074, + 0.8650834560394287, + -0.3728218674659729, + -0.5816569328308105, + 0.10949260741472244, + -0.010671291500329971, + -0.07903271913528442, + -0.09700250625610352, + 0.3192030191421509, + 0.2756008505821228, + -0.2616698145866394, + -0.11051242798566818, + 0.016789941117167473, + -0.0484573096036911, + -0.12333080172538757, + 0.0158428642898798, + 0.11172449588775635, + 0.014953864738345146, + -0.011746960692107677, + 0.05310823395848274, + 0.030244171619415283, + -0.23969320952892303, + -0.1039247065782547, + 0.285805881023407, + -0.04652552306652069, + -0.05380000174045563, + 0.05430186912417412, + 0.25547218322753906, + -0.06164371967315674, + -0.7386756539344788, + 0.4393811821937561, + 0.2623714804649353, + -0.1849273294210434, + -0.049713607877492905, + 0.1656467467546463, + 0.6638666391372681, + -0.899787187576294, + -0.5747878551483154, + 0.7465870976448059, + -0.025567445904016495, + -0.051771312952041626, + -0.19754628837108612, + 0.6828271746635437, + 0.4451557695865631, + -1.2559787034988403, + 0.07448688894510269, + 0.27905938029289246, + 0.003908769693225622, + -0.18454433977603912, + -0.011183545924723148, + 0.7449039816856384, + -0.228777676820755, + -0.47592073678970337, + 0.13784541189670563, + 0.019371675327420235, + -0.06424596160650253, + -0.1660400629043579, + 0.2080633044242859, + 0.2942465841770172, + -0.20263032615184784, + -0.0709841251373291, + -0.0021153483539819717, + -0.028180474415421486, + -0.021557176485657692, + 0.012511649169027805, + 0.06533018499612808, + 0.006560645066201687, + -0.01908997632563114, + -0.020228691399097443, + 0.10450740903615952, + 0.04476405307650566, + -0.20389842987060547, + -0.36356496810913086, + -0.18690945208072662, + 0.06581642478704453, + 0.005246834829449654, + -0.14777734875679016, + 0.04554577171802521, + 0.7314760088920593, + 1.1759854555130005, + 0.7747871279716492, + 0.08771117031574249, + 0.04425497353076935, + 0.14875195920467377, + -0.05036012455821037, + -1.0561891794204712, + -1.7835016250610352, + -1.313464879989624, + -0.4041728973388672, + -0.08825081586837769, + -0.18483860790729523, + -0.09619659930467606, + 0.6506555676460266, + 1.2331949472427368, + 1.057729721069336, + 0.3030258119106293, + 0.053314659744501114, + 0.10696353763341904, + 0.19720971584320068, + -0.19457301497459412, + -0.3546113669872284, + -0.3773464560508728, + 0.007737448439002037, + 0.007112926337867975, + -0.026632368564605713, + -0.07708505541086197, + 0.016982559114694595, + 0.03331448882818222, + 0.03235285356640816, + -0.04479134455323219, + 0.0062864539213478565, + -0.04983896017074585, + -0.014209658838808537, + 0.025105496868491173, + 0.07187403738498688, + -0.019782420247793198, + -0.0387532040476799, + 0.01098113413900137, + 0.10765481740236282, + -0.005502769257873297, + -0.29967597126960754, + -0.5370010733604431, + -0.25729984045028687, + 0.0341138020157814, + -0.01927473582327366, + -0.11736954003572464, + 0.09457080066204071, + 0.8881804943084717, + 1.5049697160720825, + 1.0347492694854736, + 0.22410355508327484, + -0.004720119293779135, + 0.1449226438999176, + -0.11916695535182953, + -1.2009364366531372, + -2.080855369567871, + -1.5549882650375366, + -0.5231477618217468, + -0.005029830615967512, + -0.11258674412965775, + 0.03710457682609558, + 0.9192798137664795, + 1.525830626487732, + 1.3018689155578613, + 0.44408130645751953, + 0.006972550880163908, + 0.07937697321176529, + 0.060622286051511765, + -0.4068094491958618, + -0.5964561104774475, + -0.6058750152587891, + -0.1743212193250656, + -0.0038881103973835707, + -0.04932431876659393, + -0.04989266395568848, + 0.07228495925664902, + 0.10359980911016464, + 0.11054171621799469, + 0.017031395807862282, + -0.012849675491452217, + -0.02224516123533249, + -0.019851619377732277, + 0.04567919671535492, + 0.12134519219398499, + 0.018673665821552277, + -0.03933878242969513, + 0.03506385162472725, + 0.07499910145998001, + -0.004981306381523609, + -0.269795298576355, + -0.4478399455547333, + -0.3141564130783081, + 0.014856644906103611, + -0.01102763693779707, + -0.11778493225574493, + -0.00048367868294008076, + 0.46917271614074707, + 0.8380635976791382, + 0.5829758048057556, + 0.14924737811088562, + 0.00504975114017725, + 0.1242799386382103, + 0.027800291776657104, + -0.5343790054321289, + -0.9185061454772949, + -0.6974499225616455, + -0.1733488291501999, + 0.028415951877832413, + -0.07513032108545303, + 0.010947657749056816, + 0.5501428246498108, + 0.8556726574897766, + 0.6854383945465088, + 0.21023745834827423, + -0.04757346957921982, + 0.028925150632858276, + -0.05005616322159767, + -0.4106282889842987, + -0.5990055203437805, + -0.5274976491928101, + -0.18928098678588867, + 0.007199999876320362, + 0.004744168370962143, + -0.006203897297382355, + 0.16117095947265625, + 0.20310591161251068, + 0.17358633875846863, + 0.057794276624917984, + 0.0018837900133803487, + -0.021730661392211914, + 0.03705505281686783, + 0.048999205231666565, + 0.017187459394335747, + -0.04760497808456421, + -0.06534644961357117, + 0.027641354128718376, + -0.02722003310918808, + -0.09557735174894333, + 0.2721945643424988, + 0.06861108541488647, + -0.17862513661384583, + 0.029542427510023117, + -0.028343068435788155, + -0.24357359111309052, + 0.2928915321826935, + 0.6317090392112732, + -0.5675624012947083, + -0.31298428773880005, + 0.119928739964962, + -0.04503166303038597, + 0.1997436285018921, + 0.9068917632102966, + -0.6105388402938843, + -1.176649808883667, + 0.391012579202652, + 0.21436090767383575, + 0.06404570490121841, + 0.4306352436542511, + -0.18372972309589386, + -1.6093186140060425, + 0.5129231810569763, + 0.8333584666252136, + -0.11607109010219574, + 0.024050598964095116, + -0.027272621169686317, + -0.8072280883789062, + 0.15613007545471191, + 1.0115277767181396, + -0.1886059194803238, + -0.1662863790988922, + -0.07484262436628342, + -0.11359186470508575, + -0.05765556916594505, + 0.48085057735443115, + 0.031143836677074432, + -0.20803743600845337, + 0.005643316078931093, + -0.011422591283917427, + -0.02063453011214733, + 0.010139239020645618, + 0.026931140571832657, + 0.02650240994989872, + 0.014503400772809982, + -0.030498046427965164, + 0.01038119662553072, + -0.041832923889160156, + -0.11747029423713684, + 0.24838468432426453, + 0.08126607537269592, + -0.17684465646743774, + 0.009867151267826557, + -0.04349489137530327, + -0.22892898321151733, + 0.3097872734069824, + 0.6229272484779358, + -0.5710748434066772, + -0.2540203332901001, + 0.15970031917095184, + -0.05765099450945854, + 0.24631772935390472, + 0.9121918678283691, + -0.6539115309715271, + -1.1680796146392822, + 0.43742635846138, + 0.1981748640537262, + 0.060766786336898804, + 0.48115089535713196, + -0.2704729437828064, + -1.668082594871521, + 0.6258481740951538, + 0.8217618465423584, + -0.17844447493553162, + 0.07583325356245041, + -0.031355466693639755, + -0.884739100933075, + 0.21298757195472717, + 1.0279508829116821, + -0.2118954360485077, + -0.16616611182689667, + -0.025157395750284195, + -0.11329160630702972, + -0.08147483319044113, + 0.46636614203453064, + 0.023730026558041573, + -0.21343427896499634, + -0.015201984904706478, + -0.00498165050521493, + 0.022955382242798805, + 0.020228328183293343, + -0.029405873268842697, + -0.032065436244010925, + 0.047389160841703415, + -0.01793060638010502, + 0.01669210195541382, + 0.05227159336209297, + -0.11703876405954361, + 0.006789325270801783, + 0.03741219639778137, + -0.04651298373937607, + -0.012846981175243855, + 0.024231625720858574, + -0.13399703800678253, + -0.024073680862784386, + 0.2970501184463501, + -0.1497301310300827, + -0.04287628084421158, + 0.08405227214097977, + -0.06020639091730118, + -0.01648692972958088, + 0.4150170087814331, + -0.17000712454319, + -0.43461430072784424, + 0.27202337980270386, + 0.006708468310534954, + -0.04474359005689621, + 0.15199843049049377, + -0.03348325565457344, + -0.6591396331787109, + 0.4057810306549072, + 0.25226324796676636, + -0.16070741415023804, + 0.03464199975132942, + 0.023064177483320236, + -0.35642316937446594, + 0.22774185240268707, + 0.37138837575912476, + -0.24171461164951324, + -0.023513946682214737, + 0.028774995356798172, + -0.02702418342232704, + -0.012504744343459606, + 0.17893734574317932, + -0.1554262489080429, + -0.09501983970403671, + 0.06177212670445442, + -0.013536165468394756, + 0.012441401369869709, + 0.006566522642970085, + -0.018207622691988945, + 0.003373368876054883, + -0.034891802817583084, + 0.002223123563453555, + 0.006169564090669155, + 0.022658145055174828, + -0.005327044054865837, + -0.023764559999108315, + -0.004386506043374538, + -0.02777106687426567, + 0.01950058527290821, + 0.004401096608489752, + 0.02882237359881401, + 0.01790205016732216, + -0.007827110588550568, + -0.005222277250140905, + -0.05361752584576607, + 0.008359426632523537, + -0.026494475081562996, + -0.015572195872664452, + -0.04412947595119476, + -0.006163781508803368, + 0.180303692817688, + 0.17117105424404144, + -0.014117442071437836, + 0.014543564058840275, + 0.03875281661748886, + 0.002004631096497178, + 0.11982911080121994, + 0.609316349029541, + 0.5792325735092163, + 0.10267578810453415, + -0.02287464588880539, + -0.011516223661601543, + -0.02587946131825447, + 0.019127164036035538, + 0.2742871046066284, + 0.23896890878677368, + -0.013414637185633183, + 0.012439075857400894, + 0.01148916780948639, + 0.0024075021501630545, + -0.028374193236231804, + -0.02938784286379814, + -0.061723873019218445, + -0.03288640081882477, + 0.010918691754341125, + 0.01171314436942339, + 0.00894222967326641, + -0.0050367508083581924, + 0.00322812981903553, + -0.01958087645471096, + 0.000401448953198269, + 0.00655051926150918, + 0.008647873997688293, + -0.015351405367255211, + -0.022286182269454002, + -0.0018973759142681956, + -0.032965533435344696, + 0.009401706047356129, + 0.01680464670062065, + 0.01722409576177597, + 0.017367251217365265, + -0.0012145076179876924, + 0.015895379707217216, + -0.013976357877254486, + 0.01587546430528164, + -0.019388504326343536, + -0.004597584251314402, + -0.026080038398504257, + 0.020517753437161446, + 0.20680218935012817, + 0.20302064716815948, + 0.03813354671001434, + 0.027738921344280243, + 0.02183712273836136, + 0.023807305842638016, + 0.14632326364517212, + 0.5991678237915039, + 0.608651340007782, + 0.15929070115089417, + -0.02112223394215107, + -0.020013611763715744, + -0.03723381832242012, + 0.032139480113983154, + 0.27032363414764404, + 0.24862462282180786, + 0.02374681644141674, + 0.007894856855273247, + 0.00042308925185352564, + -0.004832752980291843, + -0.024313796311616898, + -0.0018940505106002092, + -0.02681432105600834, + 0.002362651750445366, + 0.013330202549695969, + 0.012553646229207516, + 0.002630018163472414, + 0.002979951212182641, + 0.0015847217291593552, + -0.03376828506588936, + -0.010844729840755463, + -0.002748559694737196, + 0.012938202358782291, + -0.011872833594679832, + -0.0025761008728295565, + 0.003677211469039321, + -0.04305516183376312, + 0.001133457524701953, + 0.0020396243780851364, + 0.01797032356262207, + 0.016580887138843536, + 0.04445189982652664, + 0.013270077295601368, + -0.04839251935482025, + 0.011546633206307888, + -0.015829432755708694, + 0.019473392516374588, + -0.011464826762676239, + 0.018693143501877785, + 0.18201367557048798, + 0.16157257556915283, + 0.02082117274403572, + 0.015915032476186752, + 0.010720869526267052, + -0.0020238866563886404, + 0.09329187124967575, + 0.46998023986816406, + 0.5186727046966553, + 0.09814783185720444, + -0.016547314822673798, + 0.00325066689401865, + -0.028936590999364853, + 0.01002424769103527, + 0.21822214126586914, + 0.22012007236480713, + 0.008229314349591732, + 0.015599996782839298, + 0.014740276150405407, + 0.0019725109450519085, + 0.003613655688241124, + -0.03043546713888645, + -0.06308998167514801, + 0.014664110727608204, + 0.06775129586458206, + -0.12990300357341766, + -0.03638269379734993, + -0.03883139044046402, + 0.05194637551903725, + 0.03896122798323631, + -0.05132362246513367, + -0.07234688848257065, + -0.36106064915657043, + -0.2839237451553345, + -0.11496391147375107, + 0.3026673197746277, + 0.3528609871864319, + 0.21559017896652222, + -0.11970120668411255, + -0.5473688244819641, + -0.5362005233764648, + -0.21015112102031708, + 0.4089161455631256, + 0.6033567786216736, + 0.38614287972450256, + -0.12437233328819275, + -0.6394402384757996, + -0.6945835947990417, + -0.3482857942581177, + 0.5189254283905029, + 0.8457668423652649, + 0.6248002648353577, + -0.12700730562210083, + -0.6978924870491028, + -0.7764106392860413, + -0.4171960651874542, + 0.44747814536094666, + 0.8406224846839905, + 0.6821274161338806, + -0.07793218642473221, + -0.5459966659545898, + -0.6139025092124939, + -0.35998886823654175, + 0.27800890803337097, + 0.6048891544342041, + 0.591307520866394, + -0.04850815609097481, + -0.3863481283187866, + -0.3542836606502533, + -0.2491992861032486, + 0.1616278886795044, + 0.3402666747570038, + 0.4610227644443512, + -0.010262396186590195, + 0.0408165417611599, + 0.006382474210113287, + -0.011430315673351288, + -0.027895113453269005, + -0.009767768904566765, + 0.005882019177079201, + 0.05225436016917229, + 0.0415218211710453, + 0.08244743943214417, + 0.026765575632452965, + -0.05404946208000183, + -0.06101839989423752, + -0.028233220800757408, + 0.03128793090581894, + 0.07133004069328308, + 0.0718698799610138, + 0.042146697640419006, + -0.08380170166492462, + -0.09263177216053009, + -0.07569421827793121, + 0.032425008714199066, + 0.12351400405168533, + 0.09103626012802124, + -0.004768018145114183, + -0.05960838869214058, + -0.11922567337751389, + -0.10132396221160889, + 0.044341862201690674, + 0.100867860019207, + 0.09607693552970886, + -0.00129030947573483, + -0.05481477826833725, + -0.1278291642665863, + -0.12058380991220474, + 0.016678951680660248, + 0.09958931058645248, + 0.08456224203109741, + 0.061599165201187134, + -0.049776893109083176, + -0.11354166269302368, + -0.09844806790351868, + 0.004753128159791231, + 0.07868346571922302, + 0.06464104354381561, + 0.020981626585125923, + -0.010770543478429317, + -0.08838209509849548, + -0.07265795767307281, + -0.058313023298978806, + 0.10897739976644516, + 0.026735201478004456, + 0.03972309082746506, + -0.019998662173748016, + -0.048948734998703, + 0.03377270698547363, + 0.053406376391649246, + 0.27304399013519287, + 0.20850272476673126, + 0.07890326529741287, + -0.22241365909576416, + -0.2816997468471527, + -0.1745096743106842, + 0.08957889676094055, + 0.4962941110134125, + 0.4586986303329468, + 0.20177948474884033, + -0.3625744581222534, + -0.47758376598358154, + -0.32412785291671753, + 0.0669194757938385, + 0.5394997596740723, + 0.601328432559967, + 0.24388420581817627, + -0.4319041073322296, + -0.6893490552902222, + -0.5106037259101868, + 0.10174300521612167, + 0.5457565784454346, + 0.6549625992774963, + 0.38772058486938477, + -0.3778320252895355, + -0.6820934414863586, + -0.551069438457489, + 0.049600999802351, + 0.45137161016464233, + 0.5143972039222717, + 0.3713279068470001, + -0.26546329259872437, + -0.5121409893035889, + -0.47691628336906433, + 0.03843758627772331, + 0.30808231234550476, + 0.3185756504535675, + 0.22629432380199432, + -0.14860986173152924, + -0.2915389835834503, + -0.3552006185054779, + -0.003137432038784027, + -0.01327254343777895, + -0.027139298617839813, + 0.04800891876220703, + 0.05380738899111748, + -0.01380784809589386, + 0.0022881641052663326, + -0.012132279574871063, + 0.06182793900370598, + 0.03762871399521828, + 0.0966145321726799, + 0.08963571488857269, + 0.06551238149404526, + 0.031640589237213135, + -0.010532311163842678, + 0.07195396721363068, + 0.11343465745449066, + 0.11621421575546265, + 0.047318290919065475, + 0.1111951395869255, + 0.044054243713617325, + 0.016777141019701958, + 0.03392713516950607, + 0.06047024950385094, + -0.7924502491950989, + -0.7310910224914551, + 0.031088173389434814, + 0.0906061977148056, + 0.022829236462712288, + 0.04470035433769226, + 0.025999872013926506, + -0.8246837258338928, + -0.723675549030304, + 0.15835590660572052, + 0.07358791679143906, + -0.015819497406482697, + -0.014207872562110424, + 0.08506257086992264, + 0.08868777751922607, + 0.0976945012807846, + 0.11740022897720337, + 0.016287995502352715, + -0.024363648146390915, + 0.04249691963195801, + 0.02909177541732788, + 0.12011238187551498, + 0.10729824751615524, + 0.05927390977740288, + 0.04731644690036774, + 0.008210064843297005, + 0.03859357163310051, + -0.005175672471523285, + 0.01984376832842827, + -0.0011626111809164286, + -0.0010909241391345859, + 0.02311880886554718, + 0.007646523881703615, + 0.04582137614488602, + -0.0027255103923380375, + 0.027656713500618935, + 0.02781369723379612, + 0.015750093385577202, + 0.040563344955444336, + -0.007784596644341946, + 0.006534814368933439, + 0.002403199439868331, + -0.020037032663822174, + -0.011717663146555424, + 0.07826739549636841, + 0.018203573301434517, + 0.021228624507784843, + 0.014112413860857487, + -0.02866269089281559, + -0.9502679109573364, + -0.825043797492981, + 0.05938851460814476, + 0.06553053110837936, + 0.015418429858982563, + 0.0616452619433403, + -0.0094453701749444, + -0.9471839666366577, + -0.7922234535217285, + 0.13069523870944977, + 0.04939320683479309, + 0.007429714780300856, + 0.022599652409553528, + 0.0820123627781868, + 0.06440276652574539, + 0.09897352755069733, + 0.0856291800737381, + 0.006608777679502964, + -0.0005533680086955428, + 0.021656949073076248, + 0.014818831346929073, + 0.03757459297776222, + -0.001428246614523232, + 0.03473127633333206, + 0.03607869893312454, + 0.017313262447714806, + 0.0025767614133656025, + -0.033292777836322784, + 0.027883101254701614, + -0.007534499745815992, + -0.04302362725138664, + -0.01795666106045246, + -0.007667913101613522, + 0.012547189369797707, + -0.021762438118457794, + 0.03789107874035835, + 0.06384614109992981, + 0.0014223429607227445, + -0.01393786258995533, + -0.041693057864904404, + -0.01813604310154915, + 0.065328449010849, + 0.15736474096775055, + 0.1531635969877243, + 0.09920474886894226, + -0.04044449329376221, + 0.010558396577835083, + 0.05559245124459267, + 0.10931257158517838, + -0.5784384608268738, + -0.5109886527061462, + 0.17690584063529968, + 0.07484250515699387, + 0.010378374718129635, + 0.0890144556760788, + 0.13172735273838043, + -0.6058865785598755, + -0.49908995628356934, + 0.1835336685180664, + 0.005293308291584253, + -0.03870566934347153, + -0.025229454040527344, + 0.12571711838245392, + 0.14792272448539734, + 0.14905226230621338, + 0.0700206533074379, + -0.035034529864788055, + 0.013128797523677349, + 0.015581230632960796, + 0.005400130525231361, + 0.07070232182741165, + 0.03829728811979294, + -0.013876918703317642, + -0.019958000630140305, + -0.020086020231246948, + -0.019999003037810326, + -0.015111410059034824, + 0.11963249742984772, + -0.08270428329706192, + -0.0025947154499590397, + -0.010668564587831497, + 0.016670405864715576, + -0.03206938877701759, + -0.053453829139471054, + 0.1236601173877716, + -0.020077411085367203, + 0.00779569149017334, + -0.0318986251950264, + 0.03579804673790932, + -0.060723867267370224, + -0.009301809594035149, + 0.09249342232942581, + -0.13378725945949554, + 0.17496798932552338, + -0.0935625433921814, + 0.06569044291973114, + -0.18187756836414337, + 0.06397300213575363, + 0.3793930113315582, + -0.5664302706718445, + 0.23658618330955505, + -0.03206830099225044, + 0.03155658766627312, + 0.039305318146944046, + -0.6008145213127136, + 1.0417630672454834, + -0.5062726140022278, + -0.04698493704199791, + 0.0979752242565155, + -0.037326715886592865, + 0.26255178451538086, + -0.590207576751709, + 0.4195419251918793, + 0.12212422490119934, + -0.26122942566871643, + 0.06442253291606903, + -0.07682429254055023, + 0.12608948349952698, + -0.13872937858104706, + -0.030260663479566574, + 0.2047160565853119, + -0.13068141043186188, + 0.016608506441116333, + -0.021629147231578827, + 0.04659907519817352, + 0.024417348206043243, + 0.06751634925603867, + -0.1705978959798813, + 0.0655774399638176, + -0.0041802311316132545, + -0.02263445220887661, + -0.014069054275751114, + 0.06242800131440163, + 0.08984102308750153, + -0.19382472336292267, + 0.09380361437797546, + -0.0032764992211014032, + -0.03950225189328194, + -0.08896161615848541, + 0.28387022018432617, + 0.1668996810913086, + -0.5457127094268799, + 0.21796099841594696, + 0.012032964266836643, + 0.030721815302968025, + -0.4431600570678711, + 0.3104412257671356, + 1.0070439577102661, + -1.1077969074249268, + 0.08187273889780045, + 0.1387241780757904, + 0.09014563262462616, + -0.25378379225730896, + -0.9253583550453186, + 1.9745515584945679, + -0.6605072617530823, + -0.4394792318344116, + 0.11501576751470566, + 0.03007262572646141, + 0.2538164258003235, + -1.1462018489837646, + 0.7988958954811096, + 0.46934643387794495, + -0.4244523048400879, + -0.0001816617150325328, + -0.04351970925927162, + 0.20500127971172333, + -0.40710335969924927, + -0.15871365368366241, + 0.4640160799026489, + -0.06024328991770744, + -0.016036653891205788, + -0.012419192120432854, + 0.05552554875612259, + 0.050986770540475845, + -0.0171927809715271, + -0.12105240672826767, + 0.03947274759411812, + 0.009537882171571255, + -0.026668362319469452, + 0.017273351550102234, + 0.10812800377607346, + -0.015008139424026012, + -0.14154496788978577, + 0.08008233457803726, + -0.01306608971208334, + -0.05574854835867882, + -0.06091056764125824, + 0.2888447940349579, + 0.05022002384066582, + -0.4581625759601593, + 0.21146118640899658, + -0.01495362538844347, + 0.02946372702717781, + -0.38554418087005615, + 0.30167311429977417, + 0.7605867981910706, + -0.898481547832489, + 0.11953620612621307, + 0.12686115503311157, + 0.09949854761362076, + -0.14409342408180237, + -0.7404491901397705, + 1.5449001789093018, + -0.5307857394218445, + -0.3347839415073395, + 0.09940771013498306, + 0.009087899699807167, + 0.3081797957420349, + -0.9053899049758911, + 0.5102643370628357, + 0.4646914303302765, + -0.36200836300849915, + -0.043260715901851654, + -0.05309509113430977, + 0.22480911016464233, + -0.2674587666988373, + -0.25316888093948364, + 0.435017466545105, + -0.017485838383436203, + -0.049459364265203476, + 0.012460661120712757, + -0.02262282371520996, + -0.04392899200320244, + 0.013330060057342052, + 0.05963548645377159, + -0.020561739802360535, + -0.013496879488229752, + -0.02310933545231819, + -0.06549905985593796, + 0.12132573872804642, + 0.22165189683437347, + -0.07683887332677841, + -0.12427931278944016, + 0.05543455854058266, + 0.009089780040085316, + 0.19844494760036469, + 0.07650767266750336, + -0.48934996128082275, + -0.35080164670944214, + 0.13422781229019165, + 0.022217294201254845, + -0.006589306052774191, + -0.18357548117637634, + -0.6055922508239746, + 0.09492127597332001, + 0.7073907256126404, + 0.1777055710554123, + -0.05434347689151764, + 0.04566245526075363, + -0.023967979475855827, + 0.4856843054294586, + 0.8131930828094482, + -0.2068077027797699, + -0.3863125145435333, + 0.02887917123734951, + -0.05048410966992378, + 0.051201049238443375, + 0.057671088725328445, + -0.6412642002105713, + -0.39739903807640076, + 0.11036981642246246, + 0.06687764078378677, + -0.018151026219129562, + 0.0022760110441595316, + -0.09328305721282959, + 0.1352599710226059, + 0.19680921733379364, + 0.032235175371170044, + -0.06123670935630798, + -0.013810456730425358, + -0.01821190118789673, + -0.029903864488005638, + 0.027588335797190666, + 0.0762094110250473, + -0.046041399240493774, + 0.017117975279688835, + -0.018925148993730545, + 0.00423092395067215, + 0.2065701186656952, + 0.157025545835495, + -0.26491472125053406, + -0.24569831788539886, + 0.0873267725110054, + 0.004694689530879259, + 0.1838335543870926, + -0.18973900377750397, + -0.9744532108306885, + -0.41959065198898315, + 0.409589946269989, + 0.22223009169101715, + -0.0989728644490242, + -0.40883490443229675, + -0.8418471813201904, + 0.40256521105766296, + 1.4742398262023926, + 0.4913789629936218, + -0.14741277694702148, + -0.0028576564509421587, + 0.0861843004822731, + 1.0056577920913696, + 1.479182481765747, + -0.21940617263317108, + -0.8383130431175232, + -0.30560192465782166, + 0.12028121203184128, + 0.24013034999370575, + 0.11750353127717972, + -1.1071972846984863, + -0.9066778421401978, + -0.055051110684871674, + 0.15361995995044708, + 0.0032418384216725826, + -0.08823435008525848, + -0.3188804090023041, + -0.02160414680838585, + 0.2972750663757324, + 0.17006494104862213, + 0.03401973098516464, + 0.017106015235185623, + 0.010733614675700665, + 0.004688877146691084, + 0.02985573373734951, + 0.046415988355875015, + -0.05177726596593857, + -0.04624386876821518, + 0.026672907173633575, + 0.03479000926017761, + 0.22761401534080505, + 0.12049756944179535, + -0.23494181036949158, + -0.2207801640033722, + 0.06036320701241493, + 0.02112250216305256, + 0.16173022985458374, + -0.14196650683879852, + -0.8236543536186218, + -0.3530665934085846, + 0.3715725541114807, + 0.25781863927841187, + -0.09806561470031738, + -0.341796338558197, + -0.7201419472694397, + 0.2111824005842209, + 1.1648427248001099, + 0.3866075575351715, + -0.1955428272485733, + -0.13164694607257843, + -0.06048528477549553, + 0.7989920973777771, + 1.143347144126892, + -0.19509637355804443, + -0.6719933152198792, + -0.26912447810173035, + 0.16733723878860474, + 0.32526257634162903, + 0.1910397708415985, + -0.8516904711723328, + -0.6005953550338745, + 0.10627525299787521, + 0.16700856387615204, + 0.032433755695819855, + -0.11345972120761871, + -0.270126610994339, + -0.012052524834871292, + 0.25489771366119385, + 0.14647918939590454, + -0.014324051328003407, + -0.011148945428431034, + -0.0011708218371495605, + -0.018903911113739014, + -0.010648071765899658, + -0.017981043085455894, + 0.014055400155484676, + -0.020784996449947357, + -0.030126383528113365, + 0.1150858998298645, + -0.1112036183476448, + -0.023664508014917374, + 0.1651369333267212, + -0.055412910878658295, + -0.007318025920540094, + -0.07404221594333649, + 0.3068569302558899, + -0.6175673007965088, + 0.35226404666900635, + 0.1940349042415619, + -0.22921296954154968, + 0.06411048769950867, + 0.001689439988695085, + 0.23336739838123322, + -0.9470900893211365, + 1.2042961120605469, + -0.44587329030036926, + -0.15847182273864746, + 0.07572423666715622, + 0.11138042062520981, + -0.2075018584728241, + -0.2651064693927765, + 0.8896074295043945, + -0.7130936980247498, + 0.10370831191539764, + 0.07730382680892944, + 0.02368813008069992, + -0.20520009100437164, + 0.13611918687820435, + 0.31062978506088257, + -0.471883624792099, + 0.21489326655864716, + -0.0216743852943182, + -0.04020361602306366, + -0.022920167073607445, + 0.16054102778434753, + -0.002624030224978924, + -0.14670424163341522, + 0.12018264085054398, + -0.043656397610902786, + -0.005084550939500332, + 0.03873870149254799, + -0.07967288792133331, + -0.007439201697707176, + 0.027688704431056976, + 0.08916077762842178, + -0.0036629599053412676, + -0.01389122661203146, + 0.1402083784341812, + -0.2923351228237152, + -0.01932896114885807, + 0.224355086684227, + -0.013193303719162941, + -0.03984276205301285, + -0.04474477842450142, + 0.3302844762802124, + -0.9746807217597961, + 0.5603556036949158, + 0.3556183874607086, + -0.2713812589645386, + 0.01890619471669197, + 0.06983876973390579, + 0.09052442759275436, + -1.3613605499267578, + 1.8220031261444092, + -0.40902698040008545, + -0.31302449107170105, + 0.03893759846687317, + 0.11448371410369873, + -0.4220678210258484, + -0.3677598237991333, + 1.539440631866455, + -0.8297391533851624, + -0.08504960685968399, + 0.0629446730017662, + -0.016804160550236702, + -0.31778836250305176, + 0.2363198846578598, + 0.6452136635780334, + -0.700931191444397, + 0.09927428513765335, + 0.0019635935313999653, + -0.05397690460085869, + -0.014552262611687183, + 0.2352754771709442, + 0.09991656988859177, + -0.28891685605049133, + 0.07818552106618881, + -0.021534763276576996, + -0.009461677633225918, + -0.01069199200719595, + -0.008059840649366379, + -0.0129952197894454, + 0.038492631167173386, + 0.018906958401203156, + -0.025432486087083817, + -0.03420932963490486, + 0.09104404598474503, + -0.10342919826507568, + -0.035048507153987885, + 0.1415904313325882, + -0.052986644208431244, + -0.021596742793917656, + -0.049690280109643936, + 0.3079117238521576, + -0.5487046837806702, + 0.27024003863334656, + 0.15158434212207794, + -0.16488635540008545, + 0.027642132714390755, + 0.004561549983918667, + 0.21555493772029877, + -0.9188903570175171, + 1.0972669124603271, + -0.3528037667274475, + -0.07574182748794556, + 0.021962830796837807, + 0.08826783299446106, + -0.18681983649730682, + -0.2789378762245178, + 0.864517331123352, + -0.5642455816268921, + 0.07469761371612549, + 0.03803368657827377, + 0.014268620871007442, + -0.17712704837322235, + 0.1349189728498459, + 0.3181247115135193, + -0.45067182183265686, + 0.1391848623752594, + 0.009777083061635494, + -0.028080958873033524, + -0.03586730733513832, + 0.14503192901611328, + -0.014655024744570255, + -0.1472700834274292, + 0.07361634075641632, + -0.0029754601418972015, + -0.006887470372021198, + -0.019166842103004456, + 0.0034907464869320393, + -0.015169994905591011, + 0.053831856697797775, + -0.028789488598704338, + -0.02033298648893833, + 0.0018537036376073956, + 0.07567961513996124, + -0.07041627168655396, + -0.047083087265491486, + 0.17573483288288116, + -0.04860217124223709, + 0.013171656988561153, + 0.020158233121037483, + -0.006270059384405613, + -0.28434091806411743, + 0.2760852873325348, + 0.32198208570480347, + -0.43535903096199036, + 0.03188510239124298, + 0.019360313192009926, + -0.20063988864421844, + 0.04450676590204239, + 0.9678076505661011, + -0.683987021446228, + -0.3979112207889557, + 0.2618143558502197, + -0.049711134284734726, + -0.06456997990608215, + 0.6518288850784302, + -0.1357039213180542, + -1.1304017305374146, + 0.4881652295589447, + 0.19583553075790405, + -0.03677722439169884, + 0.21429045498371124, + 0.09559855610132217, + -0.7311355471611023, + 0.10988117009401321, + 0.4949330687522888, + -0.17359353601932526, + 0.03822369873523712, + 0.011371256783604622, + -0.1900172382593155, + -0.04778448864817619, + 0.2897090017795563, + -0.02235160581767559, + -0.05582524091005325, + 0.007624597754329443, + -0.027456223964691162, + -0.029680097475647926, + -0.023810429498553276, + 0.15409281849861145, + 0.013284318149089813, + -0.0788225457072258, + -0.025637971237301826, + 0.01406402699649334, + -0.13676859438419342, + 0.027384959161281586, + 0.30458444356918335, + -0.11150643229484558, + -0.06806201487779617, + 0.009601237252354622, + -0.0866582989692688, + -0.2328706979751587, + 0.5188567638397217, + 0.3787381649017334, + -0.655829906463623, + 0.0072118742391467094, + -0.0031494891736656427, + -0.2424815446138382, + 0.28893929719924927, + 1.2396824359893799, + -1.0406886339187622, + -0.6376030445098877, + 0.4103420078754425, + -0.05929668992757797, + 0.03918358311057091, + 0.9274081587791443, + -0.28890565037727356, + -1.6682262420654297, + 0.66976398229599, + 0.35488471388816833, + 0.027932289987802505, + 0.3169145882129669, + 0.09107685089111328, + -1.2099432945251465, + 0.11623579263687134, + 0.7632684707641602, + -0.16506360471248627, + 0.037474747747182846, + -0.005203985143452883, + -0.35939401388168335, + -0.17138688266277313, + 0.525232195854187, + 0.10247340798377991, + -0.14317406713962555, + 0.007572649512439966, + -0.006046198774129152, + 0.06188087910413742, + -0.050851333886384964, + 0.032844241708517075, + 0.0544477179646492, + -0.07947597652673721, + -0.03073730878531933, + 0.04025515541434288, + -0.010001083835959435, + -0.11831062287092209, + 0.17422229051589966, + -0.05468267202377319, + -0.04996664077043533, + 0.023996006697416306, + 0.02888253889977932, + -0.18709556758403778, + 0.13987921178340912, + 0.32867854833602905, + -0.31714990735054016, + 0.019951285794377327, + 0.027247004210948944, + -0.19416090846061707, + -0.006519266404211521, + 0.7540720105171204, + -0.5474190711975098, + -0.27137213945388794, + 0.20772530138492584, + -0.042619917541742325, + -0.09566087275743484, + 0.548494815826416, + -0.1599852293729782, + -0.9178788661956787, + 0.5456539988517761, + 0.07497559487819672, + 0.003984459210187197, + 0.18640351295471191, + 0.12121234089136124, + -0.7249511480331421, + 0.2559764087200165, + 0.4684237241744995, + -0.19216996431350708, + 0.018075481057167053, + 0.02684594877064228, + -0.221074178814888, + -0.09164194762706757, + 0.3596596121788025, + -0.08310746401548386, + -0.10815230011940002, + -0.015406409278512001, + -0.011985878460109234, + 0.028467312455177307, + -0.0879230722784996, + 0.0347294844686985, + 0.05081191286444664, + 0.00362736196257174, + 0.010529003106057644, + -0.002672453410923481, + 0.025318201631307602, + -0.06232529878616333, + 0.008822780102491379, + 0.06744717806577682, + 0.003999210894107819, + -0.0022885131184011698, + -0.046704765409231186, + 0.13673964142799377, + -0.2590992748737335, + -0.022161437198519707, + 0.258914053440094, + -0.10650330036878586, + 0.023435762152075768, + 0.06992689520120621, + 0.03760937228798866, + -0.5444027185440063, + 0.4131152629852295, + 0.25325170159339905, + -0.2482522875070572, + 0.010479461401700974, + 0.045747850090265274, + -0.1541248857975006, + -0.35291528701782227, + 0.9078133702278137, + -0.34428781270980835, + -0.14787709712982178, + -0.024105649441480637, + -0.007651817053556442, + -0.14991067349910736, + 0.17544956505298615, + 0.3692120611667633, + -0.46861159801483154, + 0.10201738774776459, + 0.003734431229531765, + -0.010433703660964966, + 0.022045455873012543, + 0.0944862961769104, + 0.01679016835987568, + -0.16537833213806152, + 0.07900089025497437, + -0.004211293533444405, + -0.01076442189514637, + 0.09729930013418198, + -0.1490965485572815, + -0.02511671558022499, + 0.0766475573182106, + 0.010980346240103245, + -0.010220799595117569, + -0.0004861881607212126, + 0.09204736351966858, + -0.179045170545578, + -0.025164175778627396, + 0.15608654916286469, + 0.004787537269294262, + -0.0005253870622254908, + 0.034556396305561066, + 0.1509256660938263, + -0.5432079434394836, + -0.03155849874019623, + 0.513609766960144, + -0.14458952844142914, + 0.015178131870925426, + 0.09172039479017258, + -0.12612608075141907, + -0.926306962966919, + 0.8281942009925842, + 0.5954549908638, + -0.492740273475647, + 0.007195526268333197, + -0.018258413299918175, + -0.4074647128582001, + -0.43008187413215637, + 1.7370752096176147, + -0.350849986076355, + -0.5158001780509949, + -0.017458094283938408, + -0.08306471258401871, + -0.2334563285112381, + 0.445117712020874, + 0.7808031439781189, + -0.7913723587989807, + -0.11814796179533005, + -0.00913319457322359, + 0.0223994143307209, + 0.1012248545885086, + 0.25349485874176025, + 0.028286214917898178, + -0.4809858798980713, + 0.05953341722488403, + 0.015634188428521156, + 0.005101620219647884, + 0.10901974141597748, + -0.11964976042509079, + -0.09117673337459564, + 0.0734483003616333, + 0.01821213960647583, + 5.350751234800555e-05, + -0.020279232412576675, + 0.1097220927476883, + -0.1354990452528, + -0.08653146773576736, + 0.11775246262550354, + -0.012575668282806873, + 0.0310806967318058, + 0.010271146893501282, + 0.20337054133415222, + -0.3854014277458191, + -0.09943562000989914, + 0.3921409249305725, + -0.08432158827781677, + 0.010676748119294643, + 0.040244489908218384, + -0.0015478944405913353, + -0.7022866010665894, + 0.49858638644218445, + 0.42338883876800537, + -0.2982582449913025, + -0.005396307446062565, + -0.008777705952525139, + -0.2325415015220642, + -0.4083922803401947, + 1.186205506324768, + -0.26399391889572144, + -0.2621048092842102, + -0.015712907537817955, + -0.04675402492284775, + -0.1797540783882141, + 0.2992522716522217, + 0.4747498333454132, + -0.5266988277435303, + 0.04581758379936218, + -0.04037958011031151, + 0.0071074217557907104, + 0.047499995678663254, + 0.16617828607559204, + -0.03973710536956787, + -0.2953551113605499, + 0.10628587752580643, + -0.00904526561498642, + 0.010427894070744514, + 0.08035022020339966, + 0.03841109946370125, + -0.06335253268480301, + -0.06992083787918091, + 0.015409895218908787, + -0.026900725439190865, + -0.04523912072181702, + 0.08087682723999023, + 0.12542113661766052, + 0.018750213086605072, + -0.23430712521076202, + 0.11755944788455963, + -0.019747508689761162, + -0.03171322122216225, + -0.12132623791694641, + 0.2640603184700012, + 0.38445138931274414, + -0.5724408030509949, + 0.15661633014678955, + 0.01949799247086048, + -0.021771302446722984, + -0.18984957039356232, + -0.23499636352062225, + 1.2112919092178345, + -0.7037869095802307, + -0.14260035753250122, + 0.01848726160824299, + 0.06443414837121964, + -0.11740390956401825, + -0.8794785141944885, + 1.4160369634628296, + 0.016899125650525093, + -0.5444768071174622, + 0.017313210293650627, + 0.0508052259683609, + 0.11102095246315002, + -0.790285587310791, + 0.3501206636428833, + 0.7238660454750061, + -0.49468666315078735, + -0.019021952524781227, + -0.01212992612272501, + 0.15032203495502472, + -0.3573611080646515, + -0.1293754130601883, + 0.45295456051826477, + -0.08407819271087646, + -0.008717959746718407, + 0.022566653788089752, + -0.012640242464840412, + 0.03181227669119835, + 0.0638526976108551, + -0.058120664209127426, + -0.042917650192976, + 0.02129550836980343, + -0.018790805712342262, + -0.00655191857367754, + 0.05951414257287979, + 0.12890471518039703, + -0.1886381357908249, + 0.059096939861774445, + -0.016928592696785927, + 0.02327263168990612, + -0.17282842099666595, + 0.13812857866287231, + 0.38889989256858826, + -0.5282873511314392, + 0.07564643770456314, + -0.006128210574388504, + -0.00876594614237547, + -0.18427829444408417, + -0.26697441935539246, + 1.2529815435409546, + -0.6549165844917297, + -0.2111111879348755, + 0.011410325765609741, + 0.07089994102716446, + -0.12627695500850677, + -0.8245998024940491, + 1.4581915140151978, + -0.01822204887866974, + -0.5626582503318787, + -0.01661459542810917, + 0.03759436681866646, + 0.10841676592826843, + -0.7652962803840637, + 0.4360819458961487, + 0.7012669444084167, + -0.47011038661003113, + 0.01529701892286539, + -0.0033166150096803904, + 0.12170535326004028, + -0.3871544301509857, + -0.05247795954346657, + 0.4504147171974182, + -0.11442532390356064, + -0.00882577896118164, + 0.005190832540392876, + -0.05153197422623634, + 0.0055236960761249065, + 0.09320031106472015, + -0.03762076050043106, + -0.021778371185064316, + 0.00750907463952899, + 0.014965789392590523, + -0.015135630965232849, + -0.037086039781570435, + 0.08020154386758804, + -0.04429963231086731, + 0.0038218852132558823, + -0.01712334342300892, + 0.053772956132888794, + -0.05226677283644676, + -0.024439912289381027, + 0.12774989008903503, + -0.18722355365753174, + 0.0683830976486206, + -0.010828870348632336, + -0.012880662456154823, + 0.02679484151303768, + -0.13696907460689545, + 0.46868517994880676, + -0.322968989610672, + 0.052930932492017746, + 0.009463602676987648, + -0.046861011534929276, + 0.07714711129665375, + -0.35792097449302673, + 0.5517901182174683, + -0.13382655382156372, + -0.12921281158924103, + 0.018562642857432365, + -0.03842621296644211, + 0.10284601897001266, + -0.28243398666381836, + 0.13314206898212433, + 0.20769073069095612, + -0.1551610678434372, + 0.018036767840385437, + -0.03553476929664612, + 0.036686040461063385, + -0.09568552672863007, + 0.008917863480746746, + 0.11340243369340897, + -0.04745811969041824, + 0.005833764094859362, + -0.04174824804067612, + 0.022730106487870216, + 0.0013601485406979918, + -0.07473982870578766, + -0.004801879171282053, + 0.05632775276899338, + -0.04081303998827934, + 0.11509573459625244, + 0.004507652949541807, + -0.24791881442070007, + 0.43171870708465576, + -0.1362573653459549, + -0.10758046060800552, + 0.02746163308620453, + -0.2954745888710022, + 0.30186471343040466, + 0.3135572075843811, + -1.2296111583709717, + 0.8754236102104187, + -0.11699853837490082, + 0.022482017055153847, + 0.24945153295993805, + -0.7858022451400757, + 0.5181443095207214, + 1.4243930578231812, + -1.876152515411377, + 0.4689188003540039, + 0.04258054122328758, + -0.030832920223474503, + 0.9340220093727112, + -1.512351632118225, + -0.3731614947319031, + 2.021338701248169, + -0.7801089286804199, + -0.09288544207811356, + -0.12423597276210785, + -0.36861127614974976, + 1.1679530143737793, + -0.4960964024066925, + -1.0398281812667847, + 0.686152458190918, + 0.02052121050655842, + 0.07246638089418411, + -0.01763315312564373, + -0.37442535161972046, + 0.33217450976371765, + 0.22260302305221558, + -0.2657756209373474, + 0.00016369696822948754, + 0.008136127144098282, + -0.03592197597026825, + 0.022231513634324074, + 0.041430093348026276, + -0.06439317017793655, + 0.03496818616986275, + -0.05143435671925545, + 0.09930871427059174, + 0.017110232263803482, + -0.3834381699562073, + 0.44344815611839294, + -0.00280396337620914, + -0.11487428843975067, + 0.050503507256507874, + -0.22837062180042267, + 0.47540077567100525, + 0.5802375674247742, + -1.7325034141540527, + 0.8587368130683899, + 0.10429240018129349, + -0.02456486038863659, + 0.1340152472257614, + -1.2299835681915283, + 0.7986555099487305, + 2.2204456329345703, + -2.4498374462127686, + 0.33742472529411316, + 0.1001473218202591, + 0.08700849115848541, + 0.9933257102966309, + -2.5278031826019287, + -0.5935835242271423, + 2.710871934890747, + -0.87749183177948, + -0.06125229224562645, + -0.19061818718910217, + -0.04017600044608116, + 1.7519460916519165, + -0.7798219919204712, + -1.28012216091156, + 0.7500321269035339, + 0.02245335467159748, + 0.08263842761516571, + -0.1563340127468109, + -0.3502165377140045, + 0.5060794949531555, + 0.11768018454313278, + -0.2394258826971054, + 0.0027446788735687733, + -0.0012661140644922853, + 0.010839025489985943, + 0.04500429332256317, + -0.04333498701453209, + -0.027386408299207687, + 0.04357098788022995, + -0.04407481476664543, + 0.08443310111761093, + -0.08108946681022644, + -0.20346391201019287, + 0.3825778365135193, + -0.16498182713985443, + -0.04287993535399437, + 0.05340999737381935, + -0.14011172950267792, + 0.29446643590927124, + 0.2738667130470276, + -1.1299961805343628, + 0.7827413082122803, + -0.07552053779363632, + -0.03602323681116104, + 0.16167275607585907, + -0.6924317479133606, + 0.4478289783000946, + 1.2428895235061646, + -1.4833877086639404, + 0.4690392315387726, + -0.00820756796747446, + -0.09873292595148087, + 0.692342221736908, + -1.0981175899505615, + -0.3906446695327759, + 1.438644528388977, + -0.719068169593811, + 0.026173872873187065, + -0.09383898228406906, + -0.3282022774219513, + 1.0363390445709229, + -0.23960772156715393, + -0.7638148069381714, + 0.5488630533218384, + -0.015319733880460262, + 0.11911362409591675, + 0.017409542575478554, + -0.4231888949871063, + 0.23724795877933502, + 0.1191876158118248, + -0.15694500505924225, + -0.03534351661801338, + 0.06342366337776184, + 0.17738288640975952, + 0.012300643138587475, + -0.06408121436834335, + -0.06030220910906792, + 0.0018237337935715914, + 0.07659764587879181, + 0.1820947527885437, + 0.24410061538219452, + -0.06998514384031296, + -0.1491813361644745, + -0.06184092164039612, + 0.04607890918850899, + 0.15362663567066193, + 0.18308304250240326, + 0.08175522834062576, + -0.305602103471756, + -0.2915116548538208, + -0.08144206553697586, + 0.07138665020465851, + -0.03521484509110451, + -0.0914112851023674, + -0.2766699492931366, + -0.6285344362258911, + -0.38168880343437195, + -0.0033710987772792578, + 0.14477019011974335, + -0.03885374590754509, + -0.11367184668779373, + -0.1979650855064392, + -0.3575190007686615, + 0.016150522977113724, + 0.28292712569236755, + 0.2836199402809143, + -0.016672370955348015, + -0.034946177154779434, + -0.014770845882594585, + -0.0004113636096008122, + 0.29938748478889465, + 0.3562523126602173, + 0.13313128054141998, + -0.029499055817723274, + 0.007187174167484045, + 0.0636785551905632, + 0.047712039202451706, + 0.20670579373836517, + 0.10999035090208054, + -0.1150810718536377, + 0.00879934523254633, + -0.009125287644565105, + -0.013732590712606907, + 0.04738131910562515, + 0.0549951009452343, + -0.014094026759266853, + -0.01195482350885868, + -0.017125386744737625, + -0.071754589676857, + -0.023961570113897324, + 0.013098018243908882, + 0.05972208455204964, + -0.032899752259254456, + -0.024354496970772743, + -0.013116234913468361, + -0.05865325778722763, + -0.006360829807817936, + 0.12809234857559204, + 0.14038555324077606, + -0.022946689277887344, + -0.039698828011751175, + 0.05144746974110603, + -0.025034509599208832, + 0.08764739334583282, + 0.24594412744045258, + 0.19307002425193787, + -0.04085381329059601, + -0.020323628559708595, + 0.022060081362724304, + 0.01799374632537365, + 0.09039195626974106, + 0.1681770235300064, + 0.0016234283102676272, + -0.23777234554290771, + -0.11634974926710129, + -0.014439117163419724, + -0.034799374639987946, + 0.0457066111266613, + 0.049919649958610535, + -0.1926913857460022, + -0.2680967450141907, + 0.0018220803467556834, + -0.012749310582876205, + -0.04389086738228798, + 0.0060565415769815445, + -0.012036234140396118, + -0.12737582623958588, + -0.05777670815587044, + 0.09932202100753784, + 0.09969642758369446, + -0.1296343356370926, + -0.2964152693748474, + -0.05487265810370445, + 0.12073978036642075, + 0.06634647399187088, + 0.004042446613311768, + -0.1586746722459793, + -0.6267098784446716, + -0.5184157490730286, + -0.032286129891872406, + 0.28023189306259155, + 0.12663227319717407, + -0.08828771114349365, + -0.2600027620792389, + -0.5287090539932251, + -0.0994620993733406, + 0.7820600271224976, + 0.9638882279396057, + 0.2193463146686554, + -0.13466303050518036, + 0.042050741612911224, + -0.02292742393910885, + 0.7523098587989807, + 1.7435946464538574, + 1.111282229423523, + -0.2104763388633728, + -0.35129284858703613, + 0.08224371820688248, + 0.11167984455823898, + 0.6513852477073669, + 0.9696454405784607, + -0.1501394510269165, + -1.1777327060699463, + -0.7738466262817383, + 0.01114045549184084, + 0.004884988535195589, + 0.2849186658859253, + 0.14232710003852844, + -1.0306764841079712, + -1.2078118324279785, + -0.14658716320991516, + 0.036605384200811386, + 0.0001495486794738099, + 0.12111346423625946, + -0.24653346836566925, + -0.7028710246086121, + -0.18977169692516327, + 0.5171932578086853, + -0.02514370158314705, + 0.0885375589132309, + -0.1023016944527626, + 0.023200739175081253, + 0.11839435249567032, + -0.09749021381139755, + 0.008283962495625019, + 0.0106261121109128, + -0.031724803149700165, + -0.1594654619693756, + 0.433218389749527, + -0.33944255113601685, + 0.14406877756118774, + -0.0339396670460701, + 0.09370072185993195, + -0.35916459560394287, + 0.7577320337295532, + -0.5531823635101318, + -0.016844574362039566, + 0.2994873523712158, + -0.21487002074718475, + -0.16125759482383728, + 0.35567227005958557, + 0.09099612385034561, + -1.3889282941818237, + 1.9466298818588257, + -1.2556309700012207, + 0.4389301836490631, + -0.010665428824722767, + 0.4707520306110382, + -1.4310415983200073, + 2.0986156463623047, + -1.5515614748001099, + 0.3905705511569977, + 0.01881679706275463, + 0.057307951152324677, + -0.29734691977500916, + 0.369127094745636, + -0.05115725100040436, + -0.44008156657218933, + 0.48642784357070923, + -0.13904061913490295, + -0.004375698510557413, + -0.06351548433303833, + 0.256020188331604, + -0.34121274948120117, + 0.22490821778774261, + 0.004067304544150829, + -0.059063635766506195, + -0.010710661299526691, + 0.03514768183231354, + -0.08577805012464523, + 0.05103181675076485, + 0.04276616871356964, + -0.10832246392965317, + 0.03325289487838745, + 0.06318283081054688, + -0.11063538491725922, + -0.062119144946336746, + 0.40978243947029114, + -0.5597845315933228, + 0.34106317162513733, + -0.030269838869571686, + 0.057014383375644684, + -0.44329890608787537, + 1.0965592861175537, + -1.0767146348953247, + 0.13287265598773956, + 0.517289400100708, + -0.310720294713974, + -0.15501761436462402, + 0.5854693055152893, + -0.12469431757926941, + -1.7694847583770752, + 2.6433238983154297, + -1.596714735031128, + 0.3888415992259979, + -0.02415616251528263, + 0.42178481817245483, + -1.8008503913879395, + 2.8845136165618896, + -1.7628657817840576, + 0.1951047033071518, + 0.11415407806634903, + 0.07305648922920227, + -0.34212157130241394, + 0.46562451124191284, + 0.03175807744264603, + -0.7942091226577759, + 0.6133171319961548, + -0.14596694707870483, + 0.010496735572814941, + -0.03459644690155983, + 0.2948842942714691, + -0.47654271125793457, + 0.2612597346305847, + 0.016025209799408913, + -0.05287598818540573, + -0.01606004498898983, + 0.022197037935256958, + 0.028397703543305397, + -0.0390767939388752, + 0.0037972000427544117, + -0.07010228931903839, + 0.10934390872716904, + 0.017220165580511093, + 0.02215729095041752, + -0.14772991836071014, + 0.2353552132844925, + -0.3846408724784851, + 0.23990634083747864, + -0.02300707995891571, + 0.12085225433111191, + -0.3576957881450653, + 0.6410096883773804, + -0.532350480556488, + -0.002389132045209408, + 0.41821879148483276, + -0.24739143252372742, + -0.10216745734214783, + 0.16793736815452576, + 0.16367803514003754, + -1.1304419040679932, + 1.676539421081543, + -1.064436435699463, + 0.26995453238487244, + -0.07634275406599045, + 0.3324422240257263, + -1.11312997341156, + 1.8095507621765137, + -1.2477567195892334, + 0.3605581820011139, + -0.06627745926380157, + 0.008511146530508995, + -0.19528241455554962, + 0.4320055842399597, + -0.22881783545017242, + -0.18463851511478424, + 0.3064245581626892, + -0.14437103271484375, + 0.02049900032579899, + 0.018321938812732697, + 0.14011529088020325, + -0.26683253049850464, + 0.2172057181596756, + -0.12119362503290176, + 0.025965997949242592, + -0.03424325957894325, + 0.0433838777244091, + 0.1072857677936554, + 0.1997794657945633, + 0.0648089200258255, + -0.06444115936756134, + -0.13146057724952698, + 0.02106364443898201, + -0.22582228481769562, + -0.007233713287860155, + 0.18876874446868896, + -0.5612399578094482, + 0.2632557451725006, + 0.44088244438171387, + 0.11389002948999405, + -0.2791701555252075, + -0.18004432320594788, + 0.8571203947067261, + -1.9517340660095215, + -1.4906251430511475, + 0.3436146676540375, + 0.31222787499427795, + -0.20083315670490265, + -0.217665895819664, + 3.801243782043457, + 1.2014728784561157, + -0.9149202704429626, + 0.6968244910240173, + 0.12756747007369995, + -0.06783506274223328, + -2.086660385131836, + 0.5455523133277893, + 0.49095916748046875, + -0.5991013050079346, + 0.7938552498817444, + -0.1335069239139557, + 0.4730406701564789, + -1.00951087474823, + -0.537578821182251, + -0.49764835834503174, + -1.2683815956115723, + -0.045739322900772095, + -0.16049732267856598, + 0.30239275097846985, + 0.035600025206804276, + 0.6344828605651855, + 0.8256548643112183, + -0.12940075993537903, + 0.09257010370492935, + -0.11000311374664307, + 0.003206665627658367, + -0.008585316129028797, + -0.14573170244693756, + 0.172541081905365, + 0.2107972949743271, + -0.05270108953118324, + -0.08480435609817505, + 0.1914149820804596, + 0.21630872786045074, + -0.23309426009655, + -0.29484814405441284, + -0.1899339109659195, + 0.02601807750761509, + -0.05416746065020561, + 0.20924429595470428, + 0.15566189587116241, + -0.1556546688079834, + -0.23387494683265686, + -0.5112816691398621, + 0.24130745232105255, + -0.049835484474897385, + -0.2685615122318268, + -0.024764614179730415, + 0.5458847880363464, + 0.9501044750213623, + 0.1328524947166443, + 0.21218529343605042, + 0.2524968683719635, + -0.5205130577087402, + -0.3361912667751312, + 1.1678112745285034, + -0.004513490945100784, + -0.9149109125137329, + 0.2125048041343689, + 0.22423015534877777, + -0.08384363353252411, + -0.2866036593914032, + -0.20210212469100952, + -1.2377471923828125, + -0.7704879641532898, + 0.365038126707077, + -0.08308980613946915, + -0.08326874673366547, + 0.456358402967453, + 0.35142943263053894, + 0.19268833100795746, + 0.3706081509590149, + -0.04951317980885506, + 0.10151109844446182, + 0.005193099845200777, + -0.1124582439661026, + -0.08353164792060852, + -0.18709596991539001, + -0.18975794315338135, + 0.17628741264343262, + 0.05536900460720062, + 0.008301885798573494, + -0.1890449970960617, + 0.056875281035900116, + 0.7981322407722473, + -0.05872391164302826, + -0.4860122501850128, + -0.08073797076940536, + 0.13145819306373596, + -0.03608228266239166, + -0.6600452661514282, + 2.243560314178467, + 1.9288626909255981, + -0.5698518753051758, + -0.2486664056777954, + 0.42693793773651123, + 0.2667267322540283, + -4.395429611206055, + -2.15342378616333, + 0.819127082824707, + -0.9362612962722778, + -0.3760467767715454, + 0.5671858787536621, + 2.468177080154419, + -1.6694080829620361, + -0.49952322244644165, + 1.502772569656372, + -1.0188850164413452, + -0.10419629514217377, + -0.36795151233673096, + 1.2645196914672852, + 0.7223924994468689, + 1.751431941986084, + 2.018704891204834, + -0.3197852671146393, + 0.22054125368595123, + -0.19326329231262207, + -0.5307535529136658, + -0.9362435936927795, + -1.0772119760513306, + -0.19870880246162415, + -0.0650869607925415, + -0.0796947032213211, + 0.15733301639556885, + 0.08798394352197647, + 0.0010860684560611844, + 0.05327683687210083, + 0.1107875183224678, + 0.13224183022975922, + 0.08979664742946625, + 0.004348093178123236, + -0.07060158997774124, + -0.19925491511821747, + -0.15811985731124878, + -0.08220887929201126, + -0.022623460739850998, + 0.08509720861911774, + 0.00792989507317543, + -0.14345014095306396, + -0.2720486521720886, + -0.18885627388954163, + -0.11063539236783981, + -0.0355350486934185, + 0.048891279846429825, + -0.12828074395656586, + -0.2712610363960266, + -0.20134924352169037, + -0.1863398402929306, + -0.19976121187210083, + -0.09535074234008789, + 0.009852319024503231, + -0.2776590585708618, + -0.3087778687477112, + -0.21431012451648712, + -0.19772370159626007, + -0.23412325978279114, + -0.11640459299087524, + 0.09514907747507095, + -0.17561811208724976, + -0.29451555013656616, + -0.2381855845451355, + -0.18296842277050018, + -0.18682444095611572, + -0.023345205932855606, + 0.1438502073287964, + 0.02504260651767254, + -0.1554802507162094, + -0.1477985382080078, + -0.07874225080013275, + -0.002977968193590641, + 0.1048416793346405, + -0.1779504120349884, + 0.13204343616962433, + 0.14215172827243805, + 0.049610622227191925, + 0.0888131782412529, + 0.07250366359949112, + 0.0696505531668663, + 0.009899160824716091, + 0.032067786902189255, + 0.08401404321193695, + -0.03567894548177719, + -0.004740188363939524, + -0.0021664693485945463, + -0.011156522668898106, + 0.0821070745587349, + 0.10295391082763672, + -0.0017653254326432943, + -0.16915833950042725, + -0.062223054468631744, + 0.004783258773386478, + 0.038355808705091476, + 0.10124270617961884, + -0.003437258303165436, + -0.18881437182426453, + -0.15905225276947021, + -0.12576808035373688, + -0.11059725284576416, + 0.021587060764431953, + 0.07237453758716583, + -0.1706620156764984, + -0.27434206008911133, + -0.23003827035427094, + -0.20530915260314941, + -0.20856624841690063, + -0.021966496482491493, + 0.13395215570926666, + -0.03810539469122887, + -0.2409798800945282, + -0.2515420913696289, + -0.1872486174106598, + -0.15951117873191833, + 0.04223426431417465, + 0.09909931570291519, + 0.12328703701496124, + -0.057749148458242416, + -0.1300545036792755, + -0.046062104403972626, + 0.019744107499718666, + 0.09484386444091797, + -0.2709728479385376, + 0.03540695831179619, + 0.1206774190068245, + 0.057636432349681854, + 0.10385740548372269, + 0.032486993819475174, + -0.020434774458408356, + -0.10122086852788925, + -0.0023329253308475018, + 0.16941140592098236, + 0.098082534968853, + 0.1250472217798233, + 0.06134447827935219, + -0.025240115821361542, + 0.004181401338428259, + 0.14425808191299438, + 0.17515034973621368, + 0.04739757999777794, + 0.1618604063987732, + 0.1751406490802765, + 0.09162088483572006, + 0.09512057155370712, + 0.13736343383789062, + 0.028775952756404877, + 0.042535409331321716, + 0.08839954435825348, + 0.09229374676942825, + 0.1658262014389038, + 0.09852072596549988, + 0.002680110279470682, + -0.05479496717453003, + -0.03634755313396454, + -0.002902726177126169, + -0.023990361019968987, + 0.1277875006198883, + 0.12727677822113037, + 0.1002269834280014, + -0.040967896580696106, + -0.07101184874773026, + -0.007902896963059902, + 0.019561029970645905, + 0.145268052816391, + 0.017638152465224266, + 0.19240263104438782, + 0.12857146561145782, + 0.05043037235736847, + 0.11596394330263138, + 0.12513381242752075, + 0.12088746577501297, + 0.04333524778485298, + 0.05500142276287079, + 0.05169082432985306, + -0.09941842406988144, + -0.005959822330623865, + -0.032586321234703064, + -0.03065132349729538, + -0.04826900362968445, + 0.14192889630794525, + 0.2543988823890686, + 0.09563885629177094, + -0.28965362906455994, + -0.1341734230518341, + 0.033991701900959015, + -0.22402706742286682, + -0.3190857768058777, + 0.011840387247502804, + 0.9620282053947449, + 1.0609054565429688, + -0.13429726660251617, + -0.20191268622875214, + 0.05324135720729828, + -0.16234318912029266, + -0.9101927280426025, + -1.7916113138198853, + 0.3981992304325104, + 1.3173034191131592, + 0.53525310754776, + 0.18472574651241302, + 0.3719426691532135, + 0.7792536020278931, + -0.027768991887569427, + -2.245561122894287, + -1.2211185693740845, + 0.22817185521125793, + -0.0023349972907453775, + -0.12598364055156708, + 0.06836964190006256, + 0.9917387366294861, + 1.1885775327682495, + -0.2851368486881256, + -0.7428704500198364, + -0.04798422381281853, + -0.00811613816767931, + -0.19619861245155334, + -0.28184008598327637, + 0.0828644260764122, + 0.44643187522888184, + 0.1461745798587799, + -0.005575121380388737, + -0.06604957580566406, + 0.011459077708423138, + 0.03927984461188316, + 0.0634538009762764, + -0.005732079967856407, + -0.01014732290059328, + 0.07607843726873398, + 0.06948187947273254, + -0.010600326582789421, + -0.056259915232658386, + -0.24602480232715607, + -0.01649448834359646, + 0.11143466085195541, + -0.0027401424013078213, + -0.012853104621171951, + 0.08452893793582916, + 0.639316201210022, + 0.5167437195777893, + -0.2775256335735321, + -0.22241903841495514, + -0.07067711651325226, + -0.06368192285299301, + -0.4687917232513428, + -1.1776493787765503, + 0.36015447974205017, + 0.9171182513237, + 0.1905054748058319, + -0.010661551728844643, + 0.10800722986459732, + 0.5352235436439514, + 0.18558207154273987, + -1.5184046030044556, + -0.8130561709403992, + 0.15417319536209106, + 0.0713079422712326, + -0.07369451224803925, + -0.09037846326828003, + 0.6168488264083862, + 0.9663773775100708, + -0.007113471627235413, + -0.33585548400878906, + -0.02738586813211441, + 0.061310965567827225, + -0.0955657884478569, + -0.23896107077598572, + -0.1107473075389862, + 0.1830059289932251, + 0.10748914629220963, + -0.040772341191768646, + -0.05803938955068588, + -0.0004895658930763602, + 0.07664632797241211, + 0.039049405604600906, + -0.002806248841807246, + -0.02642429992556572, + 0.05169009417295456, + -0.036710865795612335, + -0.1002974808216095, + -0.12001149356365204, + -0.08043934404850006, + 0.11466419696807861, + 0.12322796136140823, + 0.07564827799797058, + 0.10148002207279205, + 0.04720174893736839, + 0.14046646654605865, + -0.0819464847445488, + -0.30803975462913513, + -0.0838734582066536, + -0.0801682323217392, + 0.05861072987318039, + 0.04970559477806091, + -0.20592759549617767, + 0.2673366665840149, + 0.2431953400373459, + -0.10027645528316498, + -0.07884806394577026, + -0.09939537942409515, + 0.1181628480553627, + 0.25269386172294617, + -0.3439132571220398, + -0.11160463094711304, + 0.08640077710151672, + 0.07200870662927628, + -0.03449570760130882, + -0.17610406875610352, + -0.021308166906237602, + 0.30556705594062805, + 0.05186203494668007, + -0.004691269714385271, + -0.005278654862195253, + 0.06289899349212646, + 0.052224051207304, + -0.05927770212292671, + -0.1586783081293106, + -0.022610770538449287, + 0.03463536128401756, + 0.004338411148637533, + 0.01452699676156044, + -0.008622901514172554, + 0.010536444373428822, + -0.038111478090286255, + 0.013373414985835552, + 0.007125865668058395, + -0.003420598339289427, + 0.03533756732940674, + 0.0320388600230217, + 0.045789655297994614, + -0.08139114826917648, + -0.03447948023676872, + -0.01453007198870182, + -0.004573625046759844, + 0.10279268026351929, + 0.10881853848695755, + 0.07537791877985, + -0.10887791216373444, + -0.0980544164776802, + -0.06889445334672928, + 0.006558350287377834, + 0.197514146566391, + 0.17890937626361847, + 0.07630149275064468, + -0.16081148386001587, + -0.16685302555561066, + -0.11421715468168259, + -0.013679573312401772, + 0.22477784752845764, + 0.20761631429195404, + 0.07321957498788834, + -0.17697854340076447, + -0.17810045182704926, + -0.1579347848892212, + -0.02679254300892353, + 0.1408146619796753, + 0.15144851803779602, + 0.08801613748073578, + -0.13237154483795166, + -0.13181765377521515, + -0.1279487907886505, + -0.01779216341674328, + 0.08145096898078918, + 0.05625852569937706, + 0.07724357396364212, + -0.04653938114643097, + -0.07479449361562729, + -0.06189379468560219, + -0.04310920089483261, + 0.02028634026646614, + -0.006228619255125523, + 0.03549303859472275, + -0.043929651379585266, + 0.007818001322448254, + 0.00874761026352644, + -0.017027731984853745, + 0.11014463752508163, + 0.0841977447271347, + 0.05960552394390106, + -0.12814101576805115, + -0.0544624924659729, + -0.045333195477724075, + 0.02336869016289711, + 0.22365787625312805, + 0.18523427844047546, + 0.09366372227668762, + -0.20144090056419373, + -0.16367222368717194, + -0.13003699481487274, + 0.0590205080807209, + 0.3301562964916229, + 0.26524844765663147, + 0.09425198286771774, + -0.26156124472618103, + -0.28513699769973755, + -0.21749621629714966, + 0.04356053099036217, + 0.35879984498023987, + 0.29898661375045776, + 0.0977487862110138, + -0.28175386786460876, + -0.2964495122432709, + -0.249031201004982, + 0.028877725824713707, + 0.26395633816719055, + 0.23059280216693878, + 0.09593978524208069, + -0.22489066421985626, + -0.2248908430337906, + -0.19214706122875214, + 0.007535146549344063, + 0.15299226343631744, + 0.09148521721363068, + 0.06946425884962082, + -0.1445557326078415, + -0.11587042361497879, + -0.0978587418794632, + -0.00984917301684618, + -0.012626220472157001, + -0.02837960794568062, + 0.02399199828505516, + -0.005340439733117819, + 0.023224178701639175, + 0.011642432771623135, + 0.003958537708967924, + 0.042965203523635864, + 0.01099414099007845, + 0.024063799530267715, + -0.0702008455991745, + 0.007805663626641035, + 0.0050195748917758465, + 0.017281856387853622, + 0.10123670846223831, + 0.06401767581701279, + 0.02626805007457733, + -0.1073761060833931, + -0.03802435100078583, + -0.014407800510525703, + -0.0006281707319431007, + 0.15516239404678345, + 0.12629136443138123, + 0.033691491931676865, + -0.17609107494354248, + -0.15251316130161285, + -0.07914211601018906, + -0.015578335151076317, + 0.18422608077526093, + 0.1740245372056961, + 0.06139932945370674, + -0.17213505506515503, + -0.1602732092142105, + -0.08922445774078369, + -0.012822975404560566, + 0.13543544709682465, + 0.12543149292469025, + 0.07651004195213318, + -0.13805902004241943, + -0.09661149233579636, + -0.052669934928417206, + -0.03268992528319359, + 0.0391642227768898, + 0.01116940937936306, + 0.04585625231266022, + -0.06474924832582474, + -0.023607701063156128, + -0.007017284631729126, + -0.026150476187467575, + 0.05729387328028679, + -0.10095079243183136, + 0.16617903113365173, + -0.13664309680461884, + 0.026482274755835533, + 0.008411461487412453, + -0.03410203382372856, + 0.022963764145970345, + 0.008903563022613525, + 0.11244194954633713, + -0.20863348245620728, + 0.11064451932907104, + -0.024916114285588264, + 0.009591493755578995, + -0.26092270016670227, + 0.5717483758926392, + -0.38539814949035645, + 0.035056713968515396, + 0.08623965084552765, + -0.016184961423277855, + 0.11129201203584671, + -0.6138678789138794, + 1.3646206855773926, + -1.4969615936279297, + 0.8465064764022827, + -0.2794847786426544, + 0.05826558917760849, + 0.07709132134914398, + -0.5444677472114563, + 1.3013663291931152, + -1.5686073303222656, + 0.9930508732795715, + -0.39188963174819946, + 0.08085884898900986, + -0.05875617265701294, + 0.03498996049165726, + 0.23967482149600983, + -0.3468690514564514, + 0.19146253168582916, + 0.019604403525590897, + -0.027150027453899384, + -0.024670494720339775, + 0.09944183379411697, + -0.11718503385782242, + 0.09772855788469315, + -0.11857263743877411, + 0.09660946577787399, + -0.03638811036944389, + -0.0295167975127697, + 0.1032838523387909, + -0.12557579576969147, + 0.11812210828065872, + -0.08446288853883743, + 0.027706580236554146, + 0.010997293516993523, + -0.06348618865013123, + 0.09578556567430496, + -0.0165568757802248, + -0.014778072014451027, + -0.07772849500179291, + 0.11245536059141159, + -0.043248821049928665, + 0.013345679268240929, + -0.22149333357810974, + 0.6456363797187805, + -0.7280437350273132, + 0.3046833574771881, + 0.06304280459880829, + -0.07310052216053009, + 0.08824795484542847, + -0.65179842710495, + 1.6453673839569092, + -2.046448230743408, + 1.3267604112625122, + -0.42399832606315613, + 0.0010522910160943866, + 0.07953720539808273, + -0.5960973501205444, + 1.5601089000701904, + -2.084894895553589, + 1.4612183570861816, + -0.5491638779640198, + 0.13709494471549988, + -0.09170618653297424, + 0.07287970930337906, + 0.24422486126422882, + -0.4581631124019623, + 0.29479551315307617, + -0.07515113800764084, + -0.012292998842895031, + -0.04451148584485054, + 0.14961428940296173, + -0.15577177703380585, + 0.06323063373565674, + -0.07806269824504852, + 0.07061618566513062, + -0.026793144643306732, + -0.051938362419605255, + 0.13946141302585602, + -0.14129231870174408, + 0.11092118173837662, + -0.08889970183372498, + 0.034787945449352264, + -0.008983314968645573, + -0.04930088296532631, + 0.09856640547513962, + -0.09350966662168503, + 0.07015673816204071, + -0.06468848884105682, + 0.08028972148895264, + -0.02378295361995697, + 0.004251216538250446, + -0.11239825189113617, + 0.2660067081451416, + -0.367576539516449, + 0.2212517410516739, + -0.035011082887649536, + -0.037866897881031036, + 0.11835235357284546, + -0.4868132174015045, + 0.9402765035629272, + -1.0933791399002075, + 0.9518744349479675, + -0.5096855759620667, + 0.12277142703533173, + 0.12916085124015808, + -0.4648635983467102, + 0.8895858526229858, + -1.0776352882385254, + 1.023865818977356, + -0.5914785861968994, + 0.1682877242565155, + -0.05646277964115143, + 0.04132156819105148, + -0.01790236309170723, + -0.059831030666828156, + 0.10092897713184357, + -0.1268356889486313, + 0.013669619336724281, + -0.02746082842350006, + 0.11544085294008255, + -0.2124193012714386, + 0.2733248472213745, + -0.1360178142786026, + 0.025302443653345108, + 0.01249375008046627, + -0.015119954012334347, + 0.017966970801353455, + 0.00269943755120039, + 0.014392177574336529, + 0.007648292928934097, + 0.011665135622024536, + -0.006192799191921949, + 0.004215092398226261, + 0.017718149349093437, + 0.046436555683612823, + 0.044417623430490494, + 0.01518242433667183, + -0.0020157198887318373, + -0.01828707568347454, + -0.029163505882024765, + -0.03131464868783951, + -0.004393945913761854, + 0.048599082976579666, + 0.015757638961076736, + -0.015650734305381775, + -0.002684049541130662, + -0.0697445422410965, + -0.25050923228263855, + -0.4758685231208801, + -0.5382962822914124, + -0.38907238841056824, + -0.12599025666713715, + -0.00266047241166234, + 0.0758173018693924, + 0.26593172550201416, + 0.4203726053237915, + 0.4958920478820801, + 0.3697706162929535, + 0.12434400618076324, + 0.026325728744268417, + 0.022295912727713585, + 0.08135133236646652, + 0.2627769708633423, + 0.26325660943984985, + 0.12326934933662415, + 0.058665141463279724, + 0.04346219077706337, + -0.0013142779935151339, + -0.10037153959274292, + -0.27075886726379395, + -0.28071707487106323, + -0.17300420999526978, + -0.06914675980806351, + 0.004067219793796539, + -0.020674005150794983, + 0.02103183977305889, + 0.0033879741095006466, + 0.013523808680474758, + -0.007318845018744469, + -0.009975744411349297, + -0.02981705591082573, + 0.023193644359707832, + 0.09624253213405609, + 0.1077117845416069, + 0.11186518520116806, + 0.07592211663722992, + 0.04614634811878204, + 0.015908582136034966, + -0.05212458223104477, + -0.1262977123260498, + -0.10974782705307007, + -0.07645918428897858, + -0.06987964361906052, + -0.08783216774463654, + -0.046172842383384705, + -0.22593465447425842, + -0.5281140804290771, + -0.8424770832061768, + -0.9608982801437378, + -0.7363743185997009, + -0.3312055170536041, + -0.10426472127437592, + 0.24067367613315582, + 0.5504152178764343, + 0.81276935338974, + 0.9592635035514832, + 0.7479950785636902, + 0.32608768343925476, + 0.14525265991687775, + 0.15008939802646637, + 0.32246851921081543, + 0.5287250876426697, + 0.5817036032676697, + 0.37340155243873596, + 0.20366452634334564, + 0.1546182781457901, + -0.11224830150604248, + -0.29856279492378235, + -0.5281672477722168, + -0.5890122056007385, + -0.4024880528450012, + -0.23706914484500885, + -0.0641399398446083, + -0.0025121152866631746, + 0.0051757702603936195, + -0.014290476217865944, + 0.0043721878901124, + -0.004783981014043093, + 0.021787043660879135, + -0.004969750996679068, + -0.022116241976618767, + 0.05208030343055725, + 0.07022145390510559, + 0.03730607405304909, + 0.03242917358875275, + 0.04344351217150688, + -0.01189794484525919, + -0.0418211966753006, + -0.059125497937202454, + -0.014576594345271587, + 0.01294493954628706, + -0.011262460611760616, + -0.059920165687799454, + -0.04733816161751747, + -0.12665517628192902, + -0.29677024483680725, + -0.5247481465339661, + -0.6474934816360474, + -0.4751538038253784, + -0.1937171369791031, + -0.05117221921682358, + 0.14646948873996735, + 0.32891425490379333, + 0.5415402054786682, + 0.6071264147758484, + 0.4653589427471161, + 0.18045872449874878, + 0.09937354922294617, + 0.1264665126800537, + 0.18507222831249237, + 0.31783968210220337, + 0.3545042872428894, + 0.22468777000904083, + 0.09973976761102676, + 0.1227618008852005, + -0.07824759930372238, + -0.20465101301670074, + -0.36476215720176697, + -0.38243186473846436, + -0.2540777623653412, + -0.13525226712226868, + -0.03621843457221985, + -0.012233156710863113, + -0.01481863297522068, + -0.04313792288303375, + 0.002874002791941166, + -0.028444716706871986, + -0.04687628522515297, + -0.026806645095348358, + -0.0228339321911335, + -0.015892738476395607, + -0.015550780110061169, + 0.07011140882968903, + 0.0017389585264027119, + -0.05721491947770119, + -0.017484690994024277, + -0.03954736143350601, + -0.006339249666780233, + 0.08166316151618958, + 0.37439921498298645, + 0.2830294966697693, + 0.00668215099722147, + -0.038873329758644104, + -0.012295035645365715, + 0.04932165890932083, + 0.31826695799827576, + 0.8449289202690125, + 0.7123299241065979, + 0.2574000954627991, + 0.04747961834073067, + -0.04416817054152489, + -0.005029442720115185, + 0.2027042657136917, + 0.6639980673789978, + 0.6243636012077332, + 0.21359916031360626, + 0.027929672971367836, + -0.05395142361521721, + -0.04981911554932594, + -0.006375179626047611, + 0.23660773038864136, + 0.2155737280845642, + 0.020577391609549522, + -0.032118700444698334, + -0.02332071214914322, + -0.009217707440257072, + -0.038096409291028976, + 0.05811609327793121, + 0.03776064142584801, + -0.03570764884352684, + -0.042420413345098495, + 0.017812976613640785, + 0.019242385402321815, + 0.030057156458497047, + 0.003040613606572151, + 0.02378096617758274, + 0.04043402150273323, + 0.0243258997797966, + 0.014026327058672905, + 0.005650558043271303, + -0.002831381279975176, + -0.0645776093006134, + -0.03761167451739311, + 0.043774381279945374, + 0.010685136541724205, + 0.031011218205094337, + -0.0025828774087131023, + -0.11959855258464813, + -0.3524792194366455, + -0.30037227272987366, + -0.053334690630435944, + 0.009859252721071243, + 0.0010005333460867405, + -0.04819931834936142, + -0.3154168128967285, + -0.7240553498268127, + -0.6380828022956848, + -0.25695785880088806, + -0.06639125943183899, + 0.03295261785387993, + -0.012727363035082817, + -0.24232468008995056, + -0.6055921912193298, + -0.5679556727409363, + -0.20067356526851654, + -0.03628019988536835, + 0.04774145409464836, + 0.029560575261712074, + -0.038632482290267944, + -0.24032950401306152, + -0.2095729559659958, + -0.006905315909534693, + 0.02563827484846115, + 0.03053808957338333, + 0.0012747920118272305, + 0.004095789045095444, + -0.07932732999324799, + -0.046672020107507706, + 0.02153847925364971, + 0.019504766911268234, + -0.006118285935372114, + 0.0026654782705008984, + 0.013819373212754726, + -0.01078135147690773, + 0.0070082321763038635, + 0.00906399916857481, + 0.010149766691029072, + 0.000516490894369781, + 0.00034157291520386934, + 0.02412085421383381, + 0.006926041562110186, + 0.023299943655729294, + 0.01129852794110775, + -0.0018704778049141169, + 0.016042279079556465, + 0.023886069655418396, + 0.04207555204629898, + -0.0021778997033834457, + 0.041684601455926895, + 0.05059140920639038, + 0.03518521040678024, + -0.0032736151479184628, + -0.0007146652205847204, + 0.015503454953432083, + -0.11896659433841705, + -0.07006713002920151, + 0.007565992418676615, + 0.012584990821778774, + 0.00843358226120472, + 0.017024952918291092, + 0.0359124094247818, + -0.05997823178768158, + -0.04116949439048767, + -0.016472430899739265, + 0.002696823561564088, + 0.00829327292740345, + 0.016238784417510033, + 0.0455794483423233, + 0.0019872160628437996, + -0.005927432328462601, + -0.003552153240889311, + 0.020063765347003937, + 0.00010026743984781206, + 0.01045019831508398, + 0.034689340740442276, + 0.014206668362021446, + 0.015128945000469685, + 0.00972809735685587, + 0.019944868981838226, + 0.020581791177392006, + 0.02938947267830372, + 0.03923909366130829, + 0.03601628914475441, + 0.030168617144227028, + 0.05403255671262741, + 0.03985666483640671, + 0.020015308633446693, + 0.0285494402050972, + 0.013555807992815971, + -0.04409409686923027, + -0.07503483444452286, + 0.01716756261885166, + 0.02053452841937542, + 0.057520389556884766, + 0.02973104454576969, + -0.04563397541642189, + -0.2676408588886261, + -0.30933722853660583, + -0.11671236902475357, + 0.0020135289523750544, + 0.022801443934440613, + -0.03161352127790451, + -0.2704106271266937, + -0.5803710222244263, + -0.5762420296669006, + -0.30449461936950684, + -0.0780220776796341, + 0.017343536019325256, + -0.05319945886731148, + -0.2906038463115692, + -0.598426342010498, + -0.5925986766815186, + -0.31852787733078003, + -0.09950074553489685, + 0.05888299271464348, + 0.01939479075372219, + -0.1060815081000328, + -0.3505017161369324, + -0.3200446665287018, + -0.10609738528728485, + 0.03659524768590927, + 0.056114207953214645, + 0.03447861596941948, + 0.014380007050931454, + -0.09436371922492981, + -0.07562272250652313, + 0.04223132133483887, + 0.06327345967292786, + -0.03735652193427086, + -0.052881840616464615, + -0.058017320930957794, + -0.02474917098879814, + -0.02431381866335869, + -0.0629878118634224, + -0.05212349444627762, + -0.03820814937353134, + -0.0034579068887978792, + -0.004930540919303894, + 0.07968354970216751, + 0.07278168946504593, + 0.015167324803769588, + -0.013638288713991642, + -0.05875609815120697, + -0.008851750753819942, + 0.10708516091108322, + 0.33075177669525146, + 0.3502756953239441, + 0.14791442453861237, + 0.03131852671504021, + -0.028764141723513603, + 0.07454497367143631, + 0.3000347316265106, + 0.6147283315658569, + 0.6289594173431396, + 0.3398674726486206, + 0.13494613766670227, + -0.03705109655857086, + 0.0633230209350586, + 0.3147434592247009, + 0.595033586025238, + 0.594217836856842, + 0.33864542841911316, + 0.11264053732156754, + -0.059276629239320755, + 0.005206871312111616, + 0.14524762332439423, + 0.37473905086517334, + 0.34477534890174866, + 0.12632343173027039, + 0.011062734760344028, + -0.06149457022547722, + -0.028670497238636017, + 0.011082210578024387, + 0.13112866878509521, + 0.1106843650341034, + -0.0025933771394193172, + -0.03781202808022499, + 0.030325254425406456, + 0.017758814617991447, + 0.01635698974132538, + -0.008786264806985855, + -0.0005018062074668705, + 0.005934061016887426, + 0.020206287503242493, + 0.019497420638799667, + -0.01290479488670826, + -0.010817185044288635, + -0.032760608941316605, + -0.026973316445946693, + -0.0021766452118754387, + -0.012848617509007454, + -0.0002560729335527867, + -0.02383977733552456, + -0.05322824791073799, + -0.05382781848311424, + -0.04459262639284134, + -0.04581240937113762, + -0.03465775027871132, + 0.0026904877740889788, + -0.026097090914845467, + -0.05170493200421333, + -0.04981262609362602, + -0.05221042037010193, + -0.05268307775259018, + -0.04735802114009857, + 0.019142162054777145, + -0.019374292343854904, + -0.03312355652451515, + -0.04133244976401329, + -0.033129844814538956, + -0.01844680868089199, + -0.024726904928684235, + 0.0012146441731601954, + -0.025521529838442802, + -0.03120318427681923, + -0.04863203689455986, + -0.021450525149703026, + -0.04190714284777641, + -0.02833862416446209, + 0.017827404662966728, + -0.010181388817727566, + -0.020994380116462708, + -0.04290826618671417, + -0.031555648893117905, + -0.030525390058755875, + -0.024981478229165077, + -0.017512500286102295, + 0.019927235320210457, + 0.00433371402323246, + -0.009276121854782104, + -0.03990143537521362, + -0.021251117810606956, + 0.017825132235884666, + -0.02313065528869629, + 0.012881814502179623, + 0.0009175563463941216, + -0.0656605213880539, + -0.007037178613245487, + 0.023603176698088646, + 0.04873553663492203, + 0.013912673108279705, + 9.78652315097861e-05, + -0.03166677802801132, + -0.11772678792476654, + -0.034320034086704254, + 0.04952533170580864, + 0.10113520920276642, + 0.030472615733742714, + -0.05131377652287483, + -0.1371452510356903, + -0.2326214611530304, + -0.0629519522190094, + 0.12444627285003662, + 0.15845368802547455, + 0.014535457827150822, + -0.06888624280691147, + -0.18798232078552246, + -0.24720685184001923, + -0.04858007654547691, + 0.26889580488204956, + 0.2433905005455017, + -0.01772989332675934, + -0.06027546152472496, + -0.12164203822612762, + -0.20018024742603302, + 0.0035393801517784595, + 0.27190765738487244, + 0.1929154396057129, + -0.012923460453748703, + -0.013931642286479473, + -0.043986693024635315, + -0.0655391663312912, + 0.04751605913043022, + 0.13482201099395752, + 0.06690078228712082, + -0.01862635649740696, + 0.02938506379723549, + 0.01789080537855625, + -0.006509440019726753, + -0.029202938079833984, + -0.023693149909377098, + 0.01042762491852045, + -0.0035929735749959946, + 0.024952176958322525, + -0.013459124602377415, + -0.10798560827970505, + -0.020217353478074074, + 0.017876077443361282, + 0.07628928124904633, + 0.04444783553481102, + 0.012667268514633179, + -0.09012818336486816, + -0.22452381253242493, + -0.07556752860546112, + 0.07942477613687515, + 0.17035256326198578, + 0.0396822914481163, + -0.08236342668533325, + -0.23916372656822205, + -0.3645225763320923, + -0.10748416185379028, + 0.1996970921754837, + 0.3076043725013733, + -0.0033923503942787647, + -0.13259321451187134, + -0.28894615173339844, + -0.3605952262878418, + -0.07969008386135101, + 0.3583948314189911, + 0.4267900586128235, + -0.02228585258126259, + -0.11386624723672867, + -0.21445821225643158, + -0.26956692337989807, + 0.026791207492351532, + 0.37918713688850403, + 0.37130093574523926, + -0.05172214284539223, + -0.05132569745182991, + -0.07469630241394043, + -0.11400169134140015, + 0.07863093167543411, + 0.24061299860477448, + 0.19393151998519897, + -0.03217098489403725, + 0.013085477985441685, + 0.032348379492759705, + 0.03207695484161377, + 0.010604938492178917, + -0.026534704491496086, + -0.018284842371940613, + -0.01768680103123188, + -0.001516501884907484, + 0.013829287141561508, + -0.034318119287490845, + 0.015753330662846565, + -0.0018936718115583062, + 0.014737343415617943, + 0.03306088596582413, + 0.020835628733038902, + -0.03396771103143692, + -0.10758449137210846, + -0.03052518330514431, + 0.020080547779798508, + 0.06180800125002861, + 0.03735671192407608, + -0.037925880402326584, + -0.09720461815595627, + -0.21495617926120758, + -0.06842153519392014, + 0.08532039076089859, + 0.13350333273410797, + 0.03649023920297623, + -0.03904158994555473, + -0.1483580619096756, + -0.2068314403295517, + -0.05687328055500984, + 0.21108660101890564, + 0.21018920838832855, + 0.009318819269537926, + -0.037683792412281036, + -0.09845960140228271, + -0.1535443514585495, + 0.004504916723817587, + 0.20256847143173218, + 0.1799001693725586, + -0.03175490349531174, + -0.020391397178173065, + -0.007309200707823038, + -0.06765769422054291, + 0.013149870559573174, + 0.08469820767641068, + 0.04147877171635628, + -0.0027241194620728493, + 0.008016721345484257, + 0.001382349175401032, + 0.0001219741752720438, + -0.059255484491586685, + -0.03761141747236252, + 0.0381690077483654, + -0.01603613793849945, + 0.0017731477273628116, + -0.016544193029403687, + 0.09518970549106598, + 0.1735895872116089, + 0.005558829288929701, + -0.13464735448360443, + -0.0703420490026474, + 0.001990854274481535, + -0.03426021337509155, + -0.4390500485897064, + -0.11292288452386856, + 0.20430812239646912, + 0.14832687377929688, + 0.06074441969394684, + -0.03749264031648636, + 0.408058226108551, + 0.43119552731513977, + -0.3804298937320709, + -0.3694773018360138, + -0.03696960583329201, + 0.04022200033068657, + -0.0812998041510582, + -0.4322642385959625, + 0.19638888537883759, + 0.7809834480285645, + 0.11584538966417313, + -0.04975399747490883, + -0.015579828992486, + 0.1362757831811905, + 0.027220597490668297, + -0.4703449606895447, + -0.3726261258125305, + 0.11754196882247925, + -0.01204066164791584, + -0.00118898821529001, + -0.05152498185634613, + 0.08767394721508026, + 0.14183296263217926, + 0.01692730002105236, + -0.04587334021925926, + 0.011115594767034054, + 0.021572716534137726, + -0.021584773436188698, + -0.012763801962137222, + 0.05708793178200722, + 0.021982798352837563, + -0.02731800265610218, + 0.03000856563448906, + 0.006653181277215481, + -0.02485630102455616, + -0.20296195149421692, + -0.10483214259147644, + 0.20483383536338806, + 0.1350196748971939, + -0.08543248474597931, + 0.02644401416182518, + 0.26855263113975525, + 0.1071053072810173, + -0.8168368935585022, + -0.6617473363876343, + 0.02877889946103096, + 0.21807144582271576, + -0.02164696715772152, + -0.03712613880634308, + 0.9743875861167908, + 1.1631361246109009, + -0.45643851161003113, + -0.8180081844329834, + -0.28109386563301086, + -0.09115415811538696, + -0.4352502226829529, + -0.7433719038963318, + 0.5383746027946472, + 1.7271664142608643, + 0.509749174118042, + -0.0689467042684555, + 0.010011479258537292, + 0.11752951890230179, + -0.28825971484184265, + -1.113126277923584, + -0.6029489636421204, + 0.357056587934494, + 0.19766344130039215, + 0.023361098021268845, + 0.04305602237582207, + 0.24867205321788788, + 0.16359609365463257, + -0.2485191822052002, + -0.2251967489719391, + 0.030422789976000786, + 0.0049157580360770226, + -0.05497031658887863, + -0.030760835856199265, + 0.034536562860012054, + 0.019565051421523094, + -0.00933124776929617, + 0.01611645519733429, + 0.07988770306110382, + -0.021982649341225624, + -0.21876110136508942, + -0.10555483400821686, + 0.1893070936203003, + 0.14684906601905823, + -0.031080693006515503, + 0.09768003225326538, + 0.3261844515800476, + 0.1466774046421051, + -0.6738073825836182, + -0.5424039363861084, + 0.04689512774348259, + 0.22039148211479187, + -0.07084018737077713, + -0.07436021417379379, + 0.8260523080825806, + 1.0253428220748901, + -0.38162854313850403, + -0.727206289768219, + -0.2605172097682953, + -0.0996573269367218, + -0.3653049170970917, + -0.6791687607765198, + 0.43514078855514526, + 1.4186147451400757, + 0.38797008991241455, + -0.12675431370735168, + 0.02766786515712738, + 0.14237603545188904, + -0.2306709885597229, + -0.9204807877540588, + -0.5071616172790527, + 0.32662850618362427, + 0.20703284442424774, + -0.020968681201338768, + 0.014105334877967834, + 0.24642448127269745, + 0.20103473961353302, + -0.15519124269485474, + -0.22072142362594604, + 0.049920063465833664, + -0.05465548485517502, + 0.018651481717824936, + 0.030082669109106064, + 0.05234164372086525, + 0.10243640840053558, + 0.03569166734814644, + 0.038984544575214386, + 0.05248976871371269, + 0.24501988291740417, + 0.4674161374568939, + 0.7142530083656311, + 0.7423628568649292, + 0.6262048482894897, + 0.4019012451171875, + -0.010997634381055832, + 0.17266513407230377, + 0.4467124342918396, + 0.7795005440711975, + 0.8282667994499207, + 0.6824804544448853, + 0.3955397605895996, + 0.009771074168384075, + 0.10707246512174606, + 0.23039454221725464, + 0.33151063323020935, + 0.36120596528053284, + 0.3240644633769989, + 0.17939962446689606, + -0.01115038525313139, + -0.11081521213054657, + -0.2146066278219223, + -0.3572347164154053, + -0.44021451473236084, + -0.38320258259773254, + -0.24643990397453308, + 0.031578775495290756, + -0.21325217187404633, + -0.4312629997730255, + -0.7276368141174316, + -0.8273008465766907, + -0.718246340751648, + -0.4161607027053833, + -0.06636986136436462, + -0.28078269958496094, + -0.476252943277359, + -0.734549880027771, + -0.7796792984008789, + -0.6637035608291626, + -0.41896238923072815, + 0.021693198010325432, + 0.006199972704052925, + -0.016619624570012093, + -0.010678192600607872, + 0.012267512269318104, + 0.004102918319404125, + -0.004080160986632109, + -0.0029241242446005344, + -0.027252744883298874, + -0.0772257149219513, + -0.09107967466115952, + -0.11302012205123901, + -0.08569496124982834, + -0.07242150604724884, + -0.016465697437524796, + -0.04874062165617943, + -0.09103028476238251, + -0.09025602042675018, + -0.07523388415575027, + -0.06320428103208542, + -0.048220545053482056, + -0.028701437637209892, + -0.008647853508591652, + -0.022354092448949814, + -0.06076030433177948, + -0.030872423201799393, + -0.045786645263433456, + -0.04190178960561752, + 0.03718986362218857, + 0.021405767649412155, + 0.007675759959965944, + 0.02794131636619568, + 0.030316906049847603, + 0.007403802592307329, + 0.04861852154135704, + 0.023217258974909782, + 0.04545973241329193, + 0.07504793256521225, + 0.06824314594268799, + 0.07417462021112442, + 0.0769289955496788, + 0.0766506940126419, + -0.0028638055082410574, + 0.05911175534129143, + 0.055706772953271866, + 0.10735032707452774, + 0.10494870692491531, + 0.11092723160982132, + 0.09338293969631195, + 0.04235343262553215, + -0.022347571328282356, + -0.026347652077674866, + -0.06954608112573624, + -0.06944439560174942, + -0.05570404976606369, + -0.042987462133169174, + -0.056951191276311874, + -0.2151203453540802, + -0.3603246510028839, + -0.5899456143379211, + -0.6453464031219482, + -0.5338351726531982, + -0.31790611147880554, + 0.049492284655570984, + -0.12898015975952148, + -0.40155911445617676, + -0.6737278699874878, + -0.7170611619949341, + -0.5817899703979492, + -0.32979026436805725, + -0.005899591837078333, + -0.07673019915819168, + -0.190496027469635, + -0.34019437432289124, + -0.3314637243747711, + -0.2796767055988312, + -0.1381818801164627, + -0.008025999180972576, + 0.08429048955440521, + 0.2105528861284256, + 0.3415210545063019, + 0.4151126444339752, + 0.34003961086273193, + 0.21059827506542206, + -0.03514896333217621, + 0.1792585551738739, + 0.3903186321258545, + 0.6413942575454712, + 0.7557680010795593, + 0.6069726943969727, + 0.3415443003177643, + 0.03447553142905235, + 0.21517080068588257, + 0.4215562045574188, + 0.6151171922683716, + 0.6550290584564209, + 0.5680058002471924, + 0.33561068773269653, + -0.12205997854471207, + -0.0038300298620015383, + 0.3281119763851166, + -0.2328944057226181, + -0.03834507241845131, + 0.05432930961251259, + -0.014430212788283825, + 0.006271198857575655, + 0.32864242792129517, + 0.47277259826660156, + -0.5593215227127075, + -0.14971251785755157, + 0.13066314160823822, + -0.09738356620073318, + 0.2966129779815674, + 0.5606555342674255, + -0.3184640407562256, + -2.022890090942383, + -0.361995667219162, + 0.5496177673339844, + 0.02796279452741146, + -0.21818380057811737, + -0.5373459458351135, + -1.9538941383361816, + -1.9984712600708008, + 1.6747761964797974, + 1.5063239336013794, + -0.24534250795841217, + -0.040306344628334045, + -0.16963164508342743, + -0.40690454840660095, + 1.3548375368118286, + 3.922116279602051, + 0.8723023533821106, + -0.8986141681671143, + 0.06912416964769363, + 0.2192920595407486, + 0.352949321269989, + 1.2243634462356567, + 1.1395865678787231, + -1.5146961212158203, + -1.1557590961456299, + -0.05440744385123253, + -0.04629289731383324, + -0.002693743444979191, + -0.21906790137290955, + -0.5464610457420349, + -1.1933224201202393, + 0.01913866586983204, + 0.09363497048616409, + -0.06080613285303116, + -0.049100056290626526, + 0.04482033848762512, + -0.04087500274181366, + -0.009318803437054157, + 0.009458474814891815, + -0.09565524011850357, + -0.2264278084039688, + -0.0698866918683052, + 0.13825084269046783, + 0.014815542846918106, + -0.05801662430167198, + 0.012776852585375309, + -0.0753035843372345, + -0.07555855065584183, + 0.484436959028244, + 0.6397283673286438, + 0.12687323987483978, + -0.01779526099562645, + 0.05689511448144913, + 0.06747376173734665, + 0.26353734731674194, + 0.5908273458480835, + 0.4315526783466339, + -0.5426794290542603, + -0.44501280784606934, + -0.019558124244213104, + -0.03320806100964546, + -0.025809556245803833, + 0.17376014590263367, + -0.5201969742774963, + -1.2842578887939453, + -0.3674038052558899, + 0.0882175862789154, + -0.030023137107491493, + -0.1173325777053833, + 0.02555503323674202, + -0.39882710576057434, + -0.37364596128463745, + 0.3550366163253784, + 0.3903135359287262, + 0.04022252932190895, + 0.016731394454836845, + 0.11207644641399384, + -0.020967213436961174, + -0.028497911989688873, + 0.37590932846069336, + 0.14920172095298767, + 0.029958104714751244, + 0.039632707834243774, + -0.24969367682933807, + 0.16809938848018646, + 0.07703239470720291, + -0.03522319719195366, + -0.007072617299854755, + 0.07751759141683578, + -0.06782346963882446, + -0.4010501801967621, + 0.41269779205322266, + 0.1311105638742447, + -0.07331988960504532, + 0.08240311592817307, + -0.20034979283809662, + -0.4718745946884155, + -0.178948312997818, + 1.3285318613052368, + 0.20384186506271362, + -0.48546233773231506, + -0.09941625595092773, + 0.13249020278453827, + 0.29977336525917053, + 1.2681238651275635, + 1.5725642442703247, + -1.0834472179412842, + -1.0335719585418701, + 0.25975045561790466, + 0.06584863364696503, + 0.1609305590391159, + 0.25940945744514465, + -0.8426372408866882, + -2.590407609939575, + -0.4723183214664459, + 0.7581043243408203, + -0.03634117543697357, + -0.10199672728776932, + -0.3744191527366638, + -0.7823801636695862, + -0.7062401175498962, + 1.116550087928772, + 0.7735803127288818, + 0.012776976451277733, + 0.034575968980789185, + -0.10188565403223038, + 0.2212170958518982, + 0.5182898044586182, + 0.8056022524833679, + -0.1897655427455902, + -0.005556725896894932, + -0.003909373190253973, + -0.02175678312778473, + -0.04085654392838478, + -0.03573022410273552, + -0.0038509985897690058, + 0.02454996667802334, + 0.039437733590602875, + 0.02077251859009266, + 0.02166259102523327, + 0.17245841026306152, + 0.09513862431049347, + -0.10491111874580383, + -0.08084940910339355, + -0.026179829612374306, + 0.0215831957757473, + -0.16602416336536407, + -0.2803819179534912, + 0.23894084990024567, + 0.3269801735877991, + 0.04504352807998657, + 0.0009768904419615865, + 0.01959501951932907, + 0.24426960945129395, + -0.1451571136713028, + -0.5944203734397888, + -0.17875447869300842, + 0.028336334973573685, + 0.004323791246861219, + -0.045389141887426376, + 0.0343034490942955, + 0.46665430068969727, + 0.3707427978515625, + -0.114569291472435, + 0.04335101321339607, + -0.018011711537837982, + -0.021181274205446243, + -0.19074901938438416, + -0.20113815367221832, + 0.048786211758852005, + 0.08533122390508652, + -0.06084573268890381, + 0.01217757910490036, + 0.030666939914226532, + 0.05272842198610306, + 0.010849648155272007, + -0.05913804844021797, + -0.04202868044376373, + -0.0015147016383707523, + -0.03421122953295708, + 0.015080726705491543, + 0.12191007286310196, + 0.10450142621994019, + -0.04972418025135994, + -0.07557133585214615, + -0.02221665158867836, + -0.0861242413520813, + -0.14919178187847137, + -0.04388582333922386, + 0.4605262875556946, + 0.5697804093360901, + 0.1583399623632431, + -0.045628566294908524, + -0.05220475420355797, + -0.13630147278308868, + -0.7103163599967957, + -1.0178179740905762, + 0.1927143931388855, + 0.7479860186576843, + 0.47013771533966064, + 0.16943301260471344, + 0.2398149073123932, + 0.4710526168346405, + -0.5974176526069641, + -1.8564051389694214, + -0.7726883292198181, + 0.05584309995174408, + 0.08902852982282639, + 0.0931839719414711, + 0.46213099360466003, + 1.2080260515213013, + 0.6001025438308716, + -0.590207576751709, + -0.4145379662513733, + -0.04529324173927307, + -0.08303339034318924, + -0.2470429688692093, + -0.03481363505125046, + 0.4808541238307953, + 0.4001348614692688, + -0.1292688548564911, + -0.03635162487626076, + -0.006270444020628929, + -0.0314505510032177, + -0.13043232262134552, + -0.10837803781032562, + 0.10718243569135666, + 0.07523836195468903, + -0.00597786670550704, + 0.06580565124750137, + 0.11166563630104065, + 0.021869506686925888, + -0.10510984063148499, + -0.07651247084140778, + 0.01229890063405037, + -0.08976037800312042, + -0.14929910004138947, + -0.018859578296542168, + 0.4408939778804779, + 0.4029107689857483, + -0.05015433207154274, + -0.13887189328670502, + -0.04514491930603981, + -0.07346425950527191, + -0.5277182459831238, + -0.7335640788078308, + 0.24182197451591492, + 0.626846432685852, + 0.23399080336093903, + 0.09675730019807816, + 0.15529058873653412, + 0.42680656909942627, + -0.4012089967727661, + -1.3605350255966187, + -0.4793834686279297, + 0.10987094044685364, + 0.07592830061912537, + 0.003319029463455081, + 0.24004696309566498, + 0.9590277671813965, + 0.4946591258049011, + -0.4889579117298126, + -0.34744441509246826, + -0.020535729825496674, + -0.026767954230308533, + -0.2090117186307907, + -0.11841326951980591, + 0.37452432513237, + 0.39960840344429016, + -0.07025045901536942, + -0.022984744980931282, + 0.022319970652461052, + -0.0027356306090950966, + -0.13681942224502563, + -0.09797768294811249, + 0.09914079308509827, + 0.10856777429580688, + ] + value = numpy.array(list_value, dtype=numpy.float32).reshape((64, 3, 7, 7)) + tensor = numpy_helper.from_array(value, name="onnx::Conv_501") + + initializers.append(tensor) + + list_value = [ + 3.085598945617676, + 2.2436060905456543, + 4.244357585906982, + 1.4069645404815674, + -4.00622034072876, + 2.595770835876465, + 2.7202603816986084, + 2.4405417442321777, + 1.1759933233261108, + 2.021026372909546, + 2.6628992557525635, + 6.445226192474365, + -7.029932498931885, + 1.1305793523788452, + 2.537140369415283, + 5.456772327423096, + 4.780154705047607, + 10.039976119995117, + 2.912492275238037, + 15.781542778015137, + 2.5154318809509277, + 2.628824472427368, + 2.2992050647735596, + 2.0950584411621094, + -7.93365478515625, + 2.067786931991577, + 4.094852447509766, + 1.673399806022644, + 3.1814424991607666, + 22.49496078491211, + 2.232640027999878, + 2.6427979469299316, + -9.418174743652344, + 1.790976643562317, + 2.3774726390838623, + 2.5836219787597656, + 2.5608203411102295, + 2.287343978881836, + 2.6439085006713867, + 16.859027862548828, + 1.8699607849121094, + -3.6987526416778564, + 2.6861538887023926, + 2.8997464179992676, + 2.689293384552002, + 2.6654043197631836, + 2.3799915313720703, + 2.5603086948394775, + 3.146122694015503, + 2.715951681137085, + 2.889486789703369, + 2.966134548187256, + -4.960191249847412, + 2.6123547554016113, + 1.3074164390563965, + 2.2033026218414307, + 2.2114620208740234, + 4.132844924926758, + 4.893764495849609, + 2.6469600200653076, + 2.654136896133423, + 1.9311997890472412, + 2.881012439727783, + 2.6991193294525146, + ] + value = numpy.array(list_value, dtype=numpy.float32) + tensor = numpy_helper.from_array(value, name="onnx::Conv_502") + + initializers.append(tensor) + + list_value = [ + 0.057212892919778824, + 0.06299274414777756, + -0.018499961122870445, + -0.06501776725053787, + -0.015820641070604324, + 0.024293724447488785, + 0.05624663084745407, + -0.025112055242061615, + 0.043546054512262344, + 0.08439744263887405, + 0.005678815301507711, + 0.0034800865687429905, + 0.030301403254270554, + -0.011669250205159187, + -0.005434689112007618, + -0.1591511219739914, + 0.02324092946946621, + -0.018942436203360558, + 0.025366367772221565, + -0.07414374500513077, + 0.03468436002731323, + -0.003742520697414875, + -0.06651683896780014, + 0.005561002530157566, + 0.04527103528380394, + -0.13710148632526398, + 0.0025444801431149244, + 0.03583350405097008, + 0.015219246037304401, + -0.053635064512491226, + 0.004856681916862726, + -0.07223699986934662, + 0.016770021989941597, + 0.0012010147329419851, + 0.014582094736397266, + -0.005172556731849909, + 0.02009868621826172, + -0.0064261858351528645, + -0.029086023569107056, + 0.001915874076075852, + 0.0008194410474970937, + 0.01620865799486637, + 0.03067426010966301, + -0.0018463254673406482, + 0.05358384922146797, + -0.003966080490499735, + -0.05991416424512863, + -0.06455761194229126, + 0.01634763367474079, + -0.013959774747490883, + 0.03615918383002281, + 0.004434086848050356, + 0.02086004987359047, + -0.004025993403047323, + -0.8869641423225403, + 0.05558132007718086, + 0.024729542434215546, + -0.005809253081679344, + -0.025079259648919106, + 0.04757235199213028, + 0.0023902510292828083, + 0.01522061601281166, + 0.011692625470459461, + 0.023033330217003822, + -0.012664714828133583, + -0.29325294494628906, + -0.006855700630694628, + -0.243958979845047, + 0.0024398649111390114, + -0.060877203941345215, + -0.21996521949768066, + -0.008708474226295948, + -0.06639625877141953, + -0.03170674294233322, + -0.09708897024393082, + 0.013403226621448994, + 0.024766888469457626, + 0.2594103217124939, + -0.02221749909222126, + 0.0662861093878746, + -0.15123076736927032, + -0.010314224287867546, + -0.0029192541260272264, + 0.05985910817980766, + 0.021665453910827637, + 0.003247617743909359, + -0.006802591495215893, + 0.00772367138415575, + 0.0399332195520401, + 0.005198766943067312, + 0.006013805978000164, + -0.04212838411331177, + -0.03166411817073822, + 0.13363900780677795, + 0.006383878644555807, + -0.05536859482526779, + 0.02053261175751686, + 0.015062958002090454, + 0.03352641686797142, + -0.2944328486919403, + 0.019855381920933723, + -0.15567174553871155, + -0.06759943068027496, + 0.07467031478881836, + 0.01674237661063671, + 0.004549413453787565, + -0.0032498433720320463, + -0.1837870180606842, + -0.04725493863224983, + -0.111307792365551, + 0.022237055003643036, + 0.004200428258627653, + 0.00970534235239029, + -0.045657914131879807, + -0.024577995762228966, + 0.0035376595333218575, + 0.008936531841754913, + -0.03904002904891968, + 0.05013228952884674, + -0.011168933473527431, + -0.008444730192422867, + 0.0035155978985130787, + -0.023502476513385773, + 0.005275514908134937, + -0.09448224306106567, + -0.009177467785775661, + -0.010720008052885532, + 0.004110944457352161, + -0.0060218218713998795, + 0.058124978095293045, + -0.0016586220590397716, + 0.15812785923480988, + -0.049118027091026306, + -0.007983109913766384, + -0.04265601187944412, + -0.01627231575548649, + 0.33705562353134155, + 0.01555223111063242, + 0.035853929817676544, + 0.0005046340520493686, + 0.054810188710689545, + -0.08808254450559616, + -0.0013819067971780896, + -0.14938786625862122, + -0.019771935418248177, + 0.004152575507760048, + 0.021979758515954018, + 0.1985529363155365, + -0.07694264501333237, + 0.013187955133616924, + -0.016572976484894753, + -0.03094586730003357, + -0.03673199936747551, + -0.03916170820593834, + -0.003836784977465868, + -0.012262578122317791, + 0.005559554789215326, + 0.1488093137741089, + -0.01842501200735569, + -0.004847189411520958, + -0.02391587756574154, + 0.015824301168322563, + 0.012022596783936024, + 0.06724318116903305, + -0.032682593911886215, + 0.00450896704569459, + -0.0024625889491289854, + 0.00933725107461214, + -0.04473242908716202, + 0.06270455569028854, + -0.02062271721661091, + -0.01071448065340519, + -0.017757099121809006, + 0.01575278490781784, + -0.06489317119121552, + -0.01519051194190979, + 0.0028058059979230165, + 0.00917835533618927, + -0.01291860081255436, + -0.009537308476865292, + 0.041757628321647644, + 0.03203853219747543, + -0.10918509215116501, + -0.007152496371418238, + -0.06777876615524292, + 0.03223242610692978, + 0.01780836284160614, + -0.09791012853384018, + -0.009385241195559502, + 0.013184775598347187, + 0.0031673219054937363, + -0.010640445165336132, + 0.024713385850191116, + -0.026738369837403297, + -0.004191657993942499, + -0.13764967024326324, + -0.003720735665410757, + 0.01737186871469021, + 0.015459887683391571, + 0.033229030668735504, + 0.008042111992835999, + -0.007184108253568411, + 0.008226306177675724, + 0.0031303109135478735, + 0.0406314842402935, + -0.8669105768203735, + 0.02079751342535019, + -0.17030003666877747, + -0.03849703446030617, + 0.034153200685977936, + -0.007219486869871616, + 0.11227627843618393, + -0.2681085467338562, + 0.015872526913881302, + 0.10855260491371155, + -0.008631505072116852, + 0.02556358277797699, + 0.06043418496847153, + -0.012900532223284245, + -0.08834894001483917, + 0.028099440038204193, + -0.05156330019235611, + 0.032628703862428665, + 0.044928934425115585, + 0.006176372990012169, + 0.007333829998970032, + -0.037409231066703796, + -0.046724822372198105, + -0.011172871105372906, + 0.04603327810764313, + 0.03288746625185013, + -0.20848578214645386, + 0.0028185085393488407, + -0.032673876732587814, + 0.061944279819726944, + 0.016787173226475716, + 0.02703898213803768, + -0.0060023171827197075, + 0.06870592385530472, + 0.03154531493782997, + 0.02784041129052639, + 0.007780189625918865, + 0.02033168077468872, + 0.0019289497286081314, + 0.02545374445617199, + 0.04262726008892059, + 0.01301807351410389, + -0.023882156237959862, + 0.027872221544384956, + -0.013518108054995537, + -0.0031075032893568277, + 0.03753834590315819, + 0.0369209349155426, + -0.014378191903233528, + 0.004397932440042496, + -0.030286893248558044, + -0.007679021451622248, + -0.045032769441604614, + 0.032050322741270065, + -0.03373495861887932, + -0.04363032802939415, + 0.034301597625017166, + -0.07021668553352356, + 0.03942524269223213, + -0.11061309278011322, + 0.049139462411403656, + 0.04161922261118889, + -0.01507576834410429, + -0.012748259119689465, + 0.06599434465169907, + 0.007602245546877384, + -0.03973209857940674, + -0.06923151016235352, + 0.026153067126870155, + -0.04221056029200554, + -0.4828230142593384, + 0.03360651433467865, + 0.01847662217915058, + -0.08594681322574615, + 0.04071836546063423, + -0.0035729086957871914, + 0.0049045816995203495, + -0.036198534071445465, + 0.03046257793903351, + 0.013275806792080402, + 0.09266786277294159, + -0.03625647351145744, + -0.059672992676496506, + 0.050213005393743515, + -0.018153885379433632, + -0.0858495831489563, + 0.01621098257601261, + -0.03029749169945717, + 0.02193332649767399, + 0.0422661192715168, + 0.6109512448310852, + -0.01068826112896204, + -0.02184930257499218, + -0.03213764354586601, + -0.03148162364959717, + -0.055331334471702576, + 0.006972005590796471, + -0.00815682765096426, + 0.014874683693051338, + -0.012943249195814133, + -0.03318992629647255, + -0.0010484680533409119, + 0.005414161365479231, + -0.013610370457172394, + 0.008836873807013035, + -0.05890084058046341, + -0.022663919255137444, + -0.018899116665124893, + -0.01037894282490015, + 0.005064660683274269, + 0.08522599190473557, + 0.0075323861092329025, + 0.013720778748393059, + 0.032096460461616516, + -0.008450351655483246, + 0.020377663895487785, + 0.04537765309214592, + 0.014030816033482552, + 0.024340089410543442, + 0.0231801588088274, + -0.10347768664360046, + 0.041163086891174316, + -0.060614243149757385, + -0.09241361171007156, + 0.05831432715058327, + -0.16008608043193817, + -0.04505622759461403, + 0.04866329953074455, + -0.0656094029545784, + 0.09627313911914825, + 0.1153625100851059, + 0.008151216432452202, + 0.03813345730304718, + 0.05990723893046379, + 0.24788673222064972, + 0.06294118613004684, + 0.11761849373579025, + -0.0722033903002739, + -0.013892017304897308, + -0.016778236255049706, + 0.038522012531757355, + -0.015539593063294888, + 0.01263216882944107, + 0.0003969807003159076, + -0.0224238783121109, + -0.005919966846704483, + 0.031987495720386505, + -0.014712700620293617, + 0.03508169203996658, + 0.07568854838609695, + -0.011961974203586578, + 0.027983952313661575, + -0.03512958809733391, + -0.010324078612029552, + -0.2895449995994568, + 0.007338976487517357, + -0.042290836572647095, + -0.1640917807817459, + -0.034807007759809494, + -0.1268443465232849, + 0.18418198823928833, + -0.3867812156677246, + -0.14214494824409485, + 0.001021744217723608, + 0.11288078874349594, + 0.006741920951753855, + -0.006421610247343779, + 0.021150892600417137, + 0.02486848644912243, + 0.002660338068380952, + 0.03732302784919739, + 0.10844919830560684, + -0.032568808645009995, + 0.009477612562477589, + 0.053578171879053116, + -0.07421902567148209, + 0.05660263076424599, + 0.03038308583199978, + 0.049440011382102966, + 0.0395139642059803, + 0.0217339675873518, + 0.028231965377926826, + 0.1661153882741928, + -0.02168717049062252, + 0.055143170058727264, + -0.14159196615219116, + 0.05894732475280762, + 0.006888065952807665, + -0.06988262385129929, + 0.017527412623167038, + -0.007171930745244026, + -0.00448343763127923, + 0.02932717651128769, + -0.00652179354801774, + -0.002897858154028654, + 0.020487705245614052, + -0.027063967660069466, + -0.02539752423763275, + -0.1066114604473114, + -0.10011029988527298, + -0.03331710025668144, + -0.003807300003245473, + -0.010441976599395275, + -0.005605363752692938, + 0.09679440408945084, + 0.020033519715070724, + -0.010188378393650055, + -0.030630890280008316, + -0.00955540407449007, + 0.02825581096112728, + -0.4307324290275574, + 0.012557203881442547, + 0.043258048593997955, + 0.09386534243822098, + -0.009555542841553688, + 0.05304868891835213, + 0.014706632122397423, + -0.012911850586533546, + 0.0981304720044136, + -0.010722141712903976, + -0.027317194268107414, + 0.0893903523683548, + -0.19983792304992676, + -0.15778200328350067, + -0.1012115329504013, + -0.3758164644241333, + -0.05782865360379219, + -0.01230492815375328, + -0.37126046419143677, + -0.01596723683178425, + 0.0020407456904649734, + -0.017498979344964027, + 0.005369496997445822, + -0.023121315985918045, + 0.022279681637883186, + -0.006232256535440683, + 0.05115891620516777, + 0.006679570768028498, + 0.0026316209696233273, + 0.04291496425867081, + 0.04381528124213219, + -0.05994122102856636, + 0.007081915624439716, + -0.04571640491485596, + 0.07592425495386124, + -0.00836833007633686, + 0.008123279549181461, + -0.008003163151443005, + -0.003938044421374798, + 0.005643180105835199, + 0.016194086521863937, + -0.004063089843839407, + 0.012334472499787807, + 0.017072021961212158, + 0.005761854816228151, + 0.004702428821474314, + 0.005736868362873793, + 0.0017962371930480003, + 0.059996701776981354, + 0.19533602893352509, + 0.02649352326989174, + -0.06493135541677475, + -0.05955052375793457, + 0.015692468732595444, + -0.10623155534267426, + 0.07290898263454437, + 0.036108434200286865, + -0.01248949021100998, + 0.16444285213947296, + -0.005899128969758749, + 0.07875277101993561, + 0.0014204353792592883, + 0.03381470963358879, + -0.09680792689323425, + 0.002102318685501814, + 0.026962973177433014, + 0.031665392220020294, + -0.18168538808822632, + 0.11163855344057083, + -0.5409999489784241, + 0.07833191007375717, + -0.005324948113411665, + 0.0267564058303833, + 0.02250477857887745, + 0.03249068558216095, + -0.18441715836524963, + -0.006447427906095982, + 0.037927329540252686, + 0.0005173985846340656, + -0.02617005631327629, + 0.05929232016205788, + -0.028510913252830505, + 0.05447050556540489, + 0.012390155345201492, + 0.00046797769027762115, + -0.008598590269684792, + -0.17247197031974792, + -0.02855759859085083, + 0.033968932926654816, + -0.09011702984571457, + 0.05276056379079819, + 0.03299655020236969, + -0.005699596833437681, + -0.1954648792743683, + 0.011109501123428345, + -0.0013570536393672228, + -0.6543989181518555, + 0.009102803654968739, + 0.0407538004219532, + 0.04312055557966232, + 0.027609223499894142, + -0.035538043826818466, + 0.027167823165655136, + -0.024043193086981773, + 0.0047575319185853004, + -0.006788836792111397, + 0.025714389979839325, + 0.007848678156733513, + -0.07680192589759827, + 0.009700766764581203, + -0.0097329281270504, + 0.00586724653840065, + 0.022815868258476257, + -0.023448282852768898, + -0.05608998239040375, + 0.10786863416433334, + -0.02803603559732437, + 0.012898198328912258, + -0.009270391426980495, + -0.021972229704260826, + 0.26533082127571106, + -0.01021308358758688, + -0.01972626894712448, + 0.062940314412117, + 0.022569671273231506, + 0.027042347937822342, + -0.05669092759490013, + -0.01200617104768753, + -0.006279367487877607, + -0.009608528576791286, + -0.013600943610072136, + -0.02187415212392807, + 0.0351138636469841, + 0.006282923277467489, + -0.011123511008918285, + -0.009205769747495651, + 0.001010146806947887, + -0.4796978235244751, + -0.0030205894727259874, + -0.011987377889454365, + -0.027548225596547127, + 0.009372347965836525, + -0.005388603545725346, + -0.006444129627197981, + -0.02501147985458374, + 0.027465635910630226, + 0.027784524485468864, + 0.006878893356770277, + -0.027763860300183296, + -0.0047700353898108006, + -0.018965192139148712, + 0.027898501604795456, + 0.022454144433140755, + 0.02973407506942749, + 0.03505602851510048, + 0.04003170132637024, + -0.004336829297244549, + -0.01998550072312355, + -0.06097743660211563, + -0.07844759523868561, + 0.0013787010684609413, + 0.0066132270731031895, + -0.03124997951090336, + 0.0313432514667511, + 0.047656893730163574, + 0.06175797060132027, + -0.02077358029782772, + -0.004535601008683443, + -0.10219905525445938, + -0.07125344127416611, + -0.06927482783794403, + -0.04813461750745773, + -0.02618095651268959, + -0.01255929097533226, + -0.009180150926113129, + -0.005838831886649132, + 0.09108023345470428, + -0.032710760831832886, + 0.03091445378959179, + -0.01955563761293888, + 0.0959300771355629, + -0.09353741258382797, + -0.0761636272072792, + -0.023445438593626022, + -0.012328366748988628, + 0.05850536748766899, + -0.052494827657938004, + 0.0025638933293521404, + -0.017152179032564163, + -0.004435579292476177, + 0.12312240898609161, + -0.007241012528538704, + 0.09605048596858978, + 0.03355967625975609, + -0.015987426042556763, + -0.03470349311828613, + -0.02499505691230297, + -0.015004142187535763, + -0.018609771504998207, + -0.06654462963342667, + 0.013861652463674545, + -0.005973289255052805, + -0.04734775796532631, + 0.08755116909742355, + 0.03012942522764206, + 0.07887610793113708, + -0.01827712170779705, + 0.10793066769838333, + 0.10793614387512207, + -0.01075535174459219, + 0.03439560532569885, + 0.011567444540560246, + 0.0016386889619752765, + -0.031207261607050896, + -0.01707504875957966, + 0.20471863448619843, + 0.0025428179651498795, + 0.004082779865711927, + -0.012389302253723145, + 0.0400562584400177, + -0.21075034141540527, + 0.012872264720499516, + -0.01639414019882679, + 0.016652485355734825, + 0.0016037120949476957, + -0.006540367379784584, + -0.0068405005149543285, + -0.2484254390001297, + 0.0008089764742180705, + -0.022340824827551842, + -0.005441636312752962, + 0.002882100408896804, + 0.008654038421809673, + 0.07159754633903503, + -0.02537086047232151, + 0.011997461318969727, + -0.49913132190704346, + -0.02300887741148472, + 0.044442202895879745, + 0.001787978457286954, + 0.010291379876434803, + 0.009601960889995098, + -0.5312613248825073, + -0.014247804880142212, + 0.06685849279165268, + 0.035772595554590225, + 0.03432310372591019, + 0.03151272237300873, + -0.10318460315465927, + -0.030476456508040428, + -0.004469831008464098, + -0.16645164787769318, + -0.021104637533426285, + 0.013934006914496422, + -0.011767406016588211, + 0.008054615929722786, + 0.06089277192950249, + 0.0003409573109820485, + -0.0053401123732328415, + 0.05970478057861328, + -0.004363172687590122, + 0.014423285610973835, + -0.002795026171952486, + -0.019875092431902885, + -0.07540513575077057, + -0.09043378382921219, + 0.00750827556475997, + -0.045314721763134, + -0.00724808732047677, + 0.005193864461034536, + -0.020468784496188164, + -0.01098695583641529, + -0.0003122477210126817, + -0.007263806648552418, + -0.03325646370649338, + 0.021689830347895622, + -0.13272541761398315, + 0.02332465350627899, + -0.019292252138257027, + 0.05533658340573311, + -0.018616480752825737, + -0.015228793025016785, + -0.28432801365852356, + -0.29721561074256897, + 0.04648810625076294, + -0.014750649221241474, + -0.15370936691761017, + -0.1497083604335785, + 0.013243601657450199, + 0.042343802750110626, + -0.017519792541861534, + -0.0161418616771698, + 0.00807454064488411, + -0.023562468588352203, + -0.0315413773059845, + 0.03386805206537247, + 0.2854529917240143, + 0.0191020630300045, + -0.49126777052879333, + 0.052687134593725204, + -0.023298051208257675, + -0.009119837544858456, + 0.05149759724736214, + -0.8527837991714478, + 0.08062390983104706, + 0.057379938662052155, + -0.020724931731820107, + -0.006624895613640547, + 0.05322050303220749, + 0.017887847498059273, + 0.04229281470179558, + 0.04171830415725708, + 0.029683062806725502, + -0.00028416322311386466, + 0.1112222746014595, + -0.0448714978992939, + -0.005255761090666056, + 0.017773712053894997, + -0.0016064767260104418, + -0.013840594328939915, + -0.00398495327681303, + -4.32919041486457e-05, + 0.040796443819999695, + 0.018185198307037354, + -0.018671950325369835, + 0.0028256692457944155, + -0.020582057535648346, + 0.05567716807126999, + -0.056062404066324234, + 0.01614757999777794, + -0.0029299987945705652, + 0.048686008900403976, + 0.04299888014793396, + 0.12249592691659927, + 0.01469603180885315, + -0.1254546344280243, + -0.18532024323940277, + -0.003263876074925065, + 0.014804725535213947, + 0.004450956825166941, + -0.013681051321327686, + -0.0030781759414821863, + -0.03433656692504883, + -0.0035507124848663807, + 0.1600082814693451, + -0.028547707945108414, + -0.00989136379212141, + -0.012126478366553783, + -0.12963305413722992, + 0.008547360077500343, + 0.017959514632821083, + -0.012571084313094616, + 0.0008666724897921085, + -0.010519342496991158, + -0.009684977121651173, + -0.04285729303956032, + 0.015031769871711731, + -0.030043724924325943, + 0.018907636404037476, + 0.08019450306892395, + -0.04836742579936981, + 0.01025464478880167, + -0.004908542148768902, + -0.10327022522687912, + -0.10163667798042297, + -0.03403499722480774, + -0.019678063690662384, + -0.043049123138189316, + 0.0384567566215992, + -0.05596519634127617, + -0.09381429851055145, + -0.18688108026981354, + -0.09762943536043167, + -0.03164997324347496, + -0.006416287273168564, + 0.07003920525312424, + -0.016646990552544594, + -0.025972194969654083, + -0.028768088668584824, + -0.06332779675722122, + 0.045144014060497284, + -0.03735211119055748, + -0.010442189872264862, + 0.10948455333709717, + 0.14629514515399933, + -0.023416690528392792, + -0.01347778458148241, + 0.020830679684877396, + 0.0003131759003736079, + 0.007049075793474913, + 0.06547018885612488, + 0.03152740001678467, + 0.08380027115345001, + 0.03185325488448143, + -0.015359007753431797, + 0.08864206075668335, + 0.032676901668310165, + -0.002908645663410425, + 0.053111132234334946, + 0.0026159954722970724, + -0.05177146941423416, + -0.033048152923583984, + -0.0020293137058615685, + -0.07363513857126236, + -0.17662747204303741, + 0.004798125941306353, + 0.07139395922422409, + 0.019802849739789963, + 0.009199771098792553, + -0.009043877013027668, + -0.07681646943092346, + -0.06748555600643158, + 0.05094710737466812, + 0.0014789587585255504, + -0.0166088305413723, + -0.27988284826278687, + 0.03634800389409065, + 0.05322619527578354, + -0.15566207468509674, + -0.019964642822742462, + -0.010204506106674671, + -0.011832086369395256, + -0.0680927112698555, + -0.05793820694088936, + 0.0020100779365748167, + -0.24647225439548492, + 0.04904041066765785, + -0.05589786171913147, + -0.030167482793331146, + 0.023974033072590828, + -0.22719347476959229, + 0.019620347768068314, + -0.18078163266181946, + -0.11321499198675156, + -0.023790234699845314, + -0.1266157031059265, + 0.01117659267038107, + 0.13824795186519623, + -0.024211348965764046, + -0.0548308864235878, + 0.04849318787455559, + -0.0016174454940482974, + -0.01826266385614872, + 0.006709347013384104, + -0.350631982088089, + 0.03139018639922142, + 0.021502504125237465, + -0.12596893310546875, + 0.04311670735478401, + -0.005905786994844675, + -0.0807335153222084, + -0.07214773446321487, + -0.2054852843284607, + -0.04526854678988457, + -0.09145382046699524, + 0.002603817731142044, + -0.01951524056494236, + -0.0028278473764657974, + -0.03270411863923073, + -0.0003385065938346088, + -0.019816655665636063, + -0.003430107608437538, + 0.010664679110050201, + 0.030127109959721565, + 0.02611778862774372, + 0.030213139951229095, + 0.04682943969964981, + 0.010338326916098595, + -0.02618880569934845, + 0.014982170425355434, + -0.06979402899742126, + 0.06403722614049911, + 0.025545112788677216, + -0.11981001496315002, + 0.004320457112044096, + 0.008849565871059895, + 0.07450827211141586, + -0.04322020336985588, + -0.07648278027772903, + 0.009221173822879791, + -0.12771189212799072, + 0.027474528178572655, + -0.1637975573539734, + -0.022587651386857033, + 0.0713210329413414, + -0.09652210026979446, + -0.04942077025771141, + -0.08977267891168594, + -0.004629603121429682, + -0.09891843795776367, + 0.0004028059483971447, + 0.12999524176120758, + 0.009417874738574028, + -0.012465995736420155, + 0.09959464520215988, + 0.012048770673573017, + 0.00529639283195138, + -0.1231047734618187, + -0.010156300850212574, + -0.0067022680304944515, + 0.09231371432542801, + 0.1372271031141281, + 0.01140755694359541, + -0.014376018196344376, + 0.009014246053993702, + -0.0558021254837513, + 0.009297777898609638, + -0.023461824283003807, + 0.12312523275613785, + 0.0013492326252162457, + -0.10130659490823746, + 0.07867099344730377, + -0.04363301396369934, + -0.05203291028738022, + 0.010715829208493233, + 0.2679101228713989, + 0.047242000699043274, + 0.009700302965939045, + -0.004188477993011475, + 0.04595324397087097, + -0.10256988555192947, + 0.013266253285109997, + 0.13415516912937164, + -0.06461263447999954, + -0.04262775555253029, + 0.014638054184615612, + -0.020396970212459564, + 0.016008291393518448, + 0.012964261695742607, + 0.030219901353120804, + -0.03906702250242233, + -0.009459082037210464, + -0.006880247965455055, + 0.009383107535541058, + 0.0591101311147213, + -0.049882922321558, + -0.014105924405157566, + -0.04896679148077965, + 0.021726086735725403, + -0.013863577507436275, + -0.05801064148545265, + -0.031143831089138985, + 0.0010298469569534063, + -0.03104572743177414, + 0.1193046048283577, + 0.00880056619644165, + -0.01678626798093319, + 0.0014990485506132245, + -0.001967367948964238, + -0.0053575835190713406, + -0.006879259832203388, + -0.008937212638556957, + 0.014141763560473919, + 0.00687083275988698, + -0.0012949275551363826, + 0.017160816118121147, + -0.035110652446746826, + -0.00976842176169157, + 0.026605995371937752, + 0.004003277514129877, + 0.010927689261734486, + 0.002173327375203371, + -0.05133439600467682, + -0.04658171907067299, + 0.03023359179496765, + -0.015038624405860901, + 0.016580749303102493, + 0.02393144741654396, + 0.004817661829292774, + -0.008468102663755417, + 0.017239807173609734, + 0.019924553111195564, + 0.02557404898107052, + 0.01985766738653183, + -0.01881517469882965, + -0.14637643098831177, + -0.005403783638030291, + -0.013156545348465443, + -0.3882855176925659, + 0.01537711638957262, + 0.005061861593276262, + 0.018044542521238327, + 0.00010373388067819178, + -0.01769324019551277, + -0.020439250394701958, + 0.01761222817003727, + 0.017716309055685997, + -0.01828574948012829, + 0.0059916484169662, + 0.006117791403084993, + -0.0025541253853589296, + 0.01598154753446579, + 0.0015296537894755602, + 0.006711189169436693, + -0.005831963382661343, + 0.024547481909394264, + 0.011665170080959797, + 0.013990279287099838, + -0.009193074889481068, + -0.0014407691778615117, + 0.0025373499374836683, + -0.001535113900899887, + 0.022016262635588646, + 0.002165747107937932, + -0.00010288839985150844, + -0.01185672264546156, + 0.3959958255290985, + -0.06701132655143738, + 0.024550342932343483, + -0.007259713020175695, + 0.00011224728223169222, + 0.08959072828292847, + 0.006745494436472654, + -0.007461291737854481, + -0.0010788652580231428, + -0.003997487016022205, + 0.0023250498343259096, + 0.005845727398991585, + 0.002441686810925603, + 0.0010628585005179048, + 0.004687050357460976, + 0.03825820982456207, + 0.0027951127849519253, + 0.004356732591986656, + 0.0036379920784384012, + -0.00048690394032746553, + -0.31681910157203674, + 0.01621195860207081, + 0.009373913519084454, + -0.005099120549857616, + 0.004866141825914383, + 0.008112045004963875, + -0.009933174587786198, + -0.006929770577698946, + 0.005561198107898235, + -0.2225065976381302, + -0.00019208311277907342, + -0.003284667618572712, + 0.010527989827096462, + -0.010160842910408974, + -0.008410060778260231, + 0.004605174530297518, + 0.01542133092880249, + 0.013958578929305077, + 0.0021779180970042944, + 0.002810562262311578, + 0.001369283301755786, + -0.0003347232413943857, + 0.013902815990149975, + -0.0022218015510588884, + 0.00024955783737823367, + -0.0019350153161212802, + 0.0025213193148374557, + -0.0054915109649300575, + -0.00011564489977899939, + -0.0037644850090146065, + -0.002863431815057993, + -0.0025196163915097713, + 0.02352992817759514, + 0.00354134407825768, + -0.010700036771595478, + -0.03428381308913231, + 0.008170859888195992, + 0.005420713219791651, + -0.0013479178305715322, + 0.0015741022070869803, + -0.18286381661891937, + 0.03189067915081978, + 0.0014371845172718167, + -4.885893940809183e-05, + -0.004666821099817753, + -0.026595929637551308, + -0.0064376350492239, + 0.01583540253341198, + -0.085715651512146, + -0.00916224904358387, + -0.3605174124240875, + 0.019973354414105415, + 0.05533794313669205, + 0.053907446563243866, + 0.030877795070409775, + -0.919844925403595, + 8.968543988885358e-05, + -0.02068270742893219, + 0.012602192349731922, + 0.03245612978935242, + 0.06622699648141861, + 0.00882122665643692, + -0.03616628423333168, + -0.02428283728659153, + 0.003318701172247529, + -0.0007259293342940509, + -0.026197656989097595, + -0.059503961354494095, + 0.029495801776647568, + -0.006955073680728674, + -0.01926456019282341, + 0.009927013888955116, + 0.059641581028699875, + 0.0016886347439140081, + -0.029346982017159462, + 0.01948450319468975, + -0.04397860914468765, + 0.025248751044273376, + 0.04597266763448715, + 0.009454794228076935, + -0.018872544169425964, + -0.039650529623031616, + 0.026324709877371788, + -0.01808176562190056, + 0.028935831040143967, + 0.009501701220870018, + -0.05183069407939911, + -0.005787428934127092, + -0.021436212584376335, + 0.029735956341028214, + 0.0350160151720047, + 0.033825185149908066, + 0.03185566887259483, + 0.018431033939123154, + 0.02450188808143139, + 0.03271135315299034, + -0.0027792940381914377, + -0.0004625302099157125, + 0.01268392987549305, + 0.045023106038570404, + 0.05562014505267143, + 0.029052015393972397, + -0.002513203304260969, + -0.08349838852882385, + 7.017837560852058e-06, + -0.0014392733573913574, + 0.016982918605208397, + 0.016358936205506325, + -0.024013325572013855, + -0.004375616554170847, + -0.03734249249100685, + 0.04336351156234741, + 0.07323610782623291, + -0.0243068914860487, + 0.009403819218277931, + 0.02663031965494156, + 0.01930687017738819, + 0.02175578847527504, + 0.01639295555651188, + 0.024892140179872513, + 0.031219134107232094, + 0.02986173704266548, + -0.002100786194205284, + 0.05054357647895813, + 0.04015854373574257, + 0.0048207067884504795, + -0.03244275599718094, + 0.027246609330177307, + 0.00409608893096447, + -0.0054193479008972645, + 0.07014931738376617, + 0.009954879060387611, + 0.022472694516181946, + -0.47738370299339294, + -0.019097158685326576, + 0.028984038159251213, + -0.042564358562231064, + -0.006040808744728565, + 0.04094231128692627, + -0.007740774191915989, + -0.07854597270488739, + 0.003920051269233227, + -0.050799619406461716, + 0.023691626265645027, + 0.019952887669205666, + 0.00716764759272337, + -0.0046928380616009235, + 0.00041822553612291813, + 0.006359069608151913, + 0.017860781401395798, + -0.22999149560928345, + -0.02180831879377365, + -0.024055887013673782, + -0.0226126741617918, + -0.01795077696442604, + 0.015591473318636417, + -0.004053472075611353, + 0.016760380938649178, + 0.03378744795918465, + -0.0027090508956462145, + 0.00999806821346283, + 0.019252799451351166, + 0.0027550198137760162, + 0.03454355522990227, + -0.0295003242790699, + -0.007663591764867306, + 0.061172280460596085, + 0.049142658710479736, + -0.00858291145414114, + -0.0035321018658578396, + -0.7689260244369507, + 0.0004916944890283048, + 0.02915046364068985, + 0.017000442370772362, + -0.003298018593341112, + -0.0405484102666378, + 0.021160880103707314, + 0.0013289587805047631, + -0.07510386407375336, + 0.03890690207481384, + 0.03729970380663872, + -0.04906352981925011, + -0.10020274668931961, + 0.01506283599883318, + -0.053726132959127426, + 0.016631007194519043, + 0.03425036743283272, + 0.03358260169625282, + -0.023937245830893517, + -0.13656578958034515, + -0.13947314023971558, + 0.012915699742734432, + 0.02431132085621357, + -0.03089652583003044, + 0.1382707953453064, + 0.056695129722356796, + -0.09263960272073746, + 0.10406216233968735, + 0.02619105577468872, + -0.01678614132106304, + -0.16045455634593964, + 8.974489173851907e-05, + -0.03521093726158142, + -0.028908027336001396, + 0.21234789490699768, + -0.02046572044491768, + -0.09703273326158524, + 0.05248226970434189, + 0.011973158456385136, + 0.004557646345347166, + -0.018632734194397926, + -0.1649131029844284, + -0.00682018743827939, + -0.12712189555168152, + 0.10513507574796677, + 0.020745709538459778, + 0.02996259182691574, + -0.15409024059772491, + -0.08719073981046677, + -0.14634187519550323, + -0.16255779564380646, + -0.15963757038116455, + -0.1324772834777832, + -0.022830091416835785, + -0.06426219642162323, + -0.025459224358201027, + 0.00281702633947134, + 0.03255268186330795, + -0.05778049677610397, + -0.30381152033805847, + -0.06582051515579224, + -0.033722274005413055, + 0.014956191182136536, + 0.004153797868639231, + 0.2391217201948166, + -0.0311420951038599, + 0.001518488978035748, + 0.019769812002778053, + -0.056324463337659836, + -0.006009253207594156, + -0.21367721259593964, + -0.0481688529253006, + 0.22422266006469727, + 0.0402204655110836, + 0.1432792693376541, + 0.14159953594207764, + -0.0025862890761345625, + -0.028965365141630173, + 0.011978867463767529, + 0.161293163895607, + 0.028642605990171432, + -0.008417634293437004, + -0.10145614296197891, + 0.08381767570972443, + 0.05199432373046875, + 0.18680602312088013, + -0.023287687450647354, + 0.03601476550102234, + 0.03738229721784592, + 0.19291405379772186, + 0.03553088754415512, + 0.05483124405145645, + 0.09577616304159164, + -0.004635817836970091, + 0.052481625229120255, + -0.042084019631147385, + -0.2629147469997406, + -0.006157668773084879, + -0.0401761569082737, + 0.02154349908232689, + -0.056558139622211456, + -0.003753019031137228, + 0.01922912523150444, + 0.1291409730911255, + -0.21358416974544525, + 0.004696246236562729, + 0.13787509500980377, + -0.07022479176521301, + -0.06828727573156357, + 0.09193858504295349, + -0.06863763928413391, + -0.05677935853600502, + -0.030970478430390358, + -0.10181070864200592, + -0.1247706487774849, + 0.014181962236762047, + -0.09259836375713348, + -0.03174220770597458, + -0.014812505804002285, + -0.024658311158418655, + -0.04815720021724701, + -0.01683010160923004, + 0.015726473182439804, + 0.002938281511887908, + -0.1586887538433075, + -0.29276973009109497, + -0.029981529340147972, + -0.046828676015138626, + -0.04909103736281395, + 0.06043976545333862, + 0.03698069602251053, + -0.04807118698954582, + 0.0943484902381897, + 0.01930702105164528, + 0.06498143821954727, + 0.0381690077483654, + -0.19611406326293945, + 0.006944946013391018, + 0.06454038619995117, + -0.19779883325099945, + 0.04966692253947258, + 0.046355295926332474, + 0.0590626522898674, + -0.24392037093639374, + -0.0018132536206394434, + 0.010944955050945282, + -0.014556891284883022, + 0.051466893404722214, + -0.0059846509248018265, + -0.06719732284545898, + 0.030604040250182152, + 0.051190104335546494, + -0.053196243941783905, + -0.06912374496459961, + -0.06263922154903412, + 0.05626852437853813, + 0.013047950342297554, + -0.005828890949487686, + 0.056055404245853424, + 0.007044378202408552, + 0.030499491840600967, + -0.035373322665691376, + 0.030934391543269157, + 0.04358363524079323, + 0.001537138712592423, + 0.005963161122053862, + -0.005889860913157463, + 0.053225863724946976, + 0.052091702818870544, + -0.02871675044298172, + 0.05662619322538376, + -0.4585985839366913, + 0.06490323692560196, + 0.02542230300605297, + 0.017592567950487137, + 0.05066920816898346, + -0.20954127609729767, + -0.06689731031656265, + -0.3632309138774872, + -0.03407476842403412, + 0.04976007342338562, + 0.03856723755598068, + 0.009329214692115784, + -0.10107281804084778, + 0.007077769376337528, + -0.005482642911374569, + 0.04388934373855591, + 0.03984231874346733, + 0.005358297843486071, + 0.05032944679260254, + 0.007170544005930424, + 0.017318176105618477, + -0.03577208146452904, + -0.02195456624031067, + 0.014414021745324135, + -0.008203372359275818, + 0.04585091397166252, + -0.012298643589019775, + 0.03959968313574791, + -0.06015963852405548, + -0.1360240876674652, + -0.07704123109579086, + -0.0842466950416565, + -0.11261942237615585, + 0.0433686338365078, + -0.1059969812631607, + 0.014813154004514217, + 0.04216694459319115, + 0.10441470146179199, + 0.04579426348209381, + 0.026033954694867134, + 0.08725529909133911, + -0.14662955701351166, + -0.0726592168211937, + 0.1293957382440567, + 0.013497715815901756, + -0.01318936888128519, + -0.05188713222742081, + 0.08793413639068604, + 0.1094818189740181, + 0.07991892844438553, + 0.03549068048596382, + -0.04469897970557213, + -0.10442564636468887, + 0.13456915318965912, + 0.01154977548867464, + -0.05959299951791763, + 0.01768219843506813, + 0.0179652888327837, + -0.010112428106367588, + 0.020603090524673462, + -0.7144030928611755, + 0.20126283168792725, + 0.058172807097435, + -0.10543914139270782, + 0.07461538910865784, + -0.1744592934846878, + 0.055722273886203766, + -0.046595826745033264, + 0.06237049773335457, + 0.05800141766667366, + 0.04118870943784714, + 0.002582935383543372, + 0.010623090900480747, + -0.0439014658331871, + 0.044685740023851395, + -0.017063472419977188, + -0.0173367727547884, + -0.04761765897274017, + 0.06136244907975197, + 0.08495236933231354, + 0.24923592805862427, + -0.061080869287252426, + 0.15922360122203827, + -0.09322690963745117, + -0.09617402404546738, + 0.0029533954802900553, + 0.12630371749401093, + 0.0011397749185562134, + 0.0005059551913291216, + -0.060922350734472275, + -0.16446451842784882, + 0.057099178433418274, + 0.03073902614414692, + -0.031064951792359352, + 0.012277435511350632, + 0.020447896793484688, + 0.06010727211833, + 0.07065457105636597, + 0.026963504031300545, + 0.010798406787216663, + -0.02631279267370701, + 0.02046871930360794, + -0.004800989292562008, + -0.03282550349831581, + 0.053904879838228226, + -0.03294985368847847, + -0.4204113185405731, + 0.028552187606692314, + 0.023685462772846222, + 0.0017703581834211946, + 0.02868991158902645, + -0.3585520088672638, + -0.011516556143760681, + -0.00248165475204587, + 0.011379038915038109, + 0.0459531806409359, + 0.015357235446572304, + 0.05573337897658348, + 0.06516549736261368, + 0.02981666848063469, + 0.05498211458325386, + 0.028714550659060478, + -0.005899528972804546, + 0.008476868271827698, + 0.11328839510679245, + 0.020578190684318542, + -0.15382742881774902, + 0.015724696218967438, + -0.08402770012617111, + 0.060314107686281204, + 0.032343748956918716, + 0.014438764192163944, + -0.13614842295646667, + -0.0017508765449747443, + 0.09998518973588943, + -0.06364594399929047, + 0.049632295966148376, + -0.11922458559274673, + -0.08834195137023926, + 0.019541991874575615, + 0.06320779770612717, + 0.017419861629605293, + -0.0028468866366893053, + -0.14753428101539612, + 0.02623703144490719, + -0.011462770402431488, + 0.06676206737756729, + -0.014891563914716244, + -0.002118025440722704, + 0.02519390918314457, + -0.29581141471862793, + 0.0264339130371809, + 0.04027356952428818, + 0.00412194337695837, + 0.03778498247265816, + -0.012331741861999035, + 0.15336745977401733, + -0.034510836005210876, + 0.0319819413125515, + 0.01916184462606907, + 0.04952343553304672, + -0.026733938604593277, + -0.014996573328971863, + 0.0010714810341596603, + 0.01959756202995777, + -0.0392388179898262, + -0.0052064210176467896, + -0.05015777423977852, + -0.0002977418771479279, + -0.04029487073421478, + -0.012846150435507298, + -0.09198840707540512, + 0.0118671590462327, + -0.06176264211535454, + 0.006427878048270941, + 0.04043034091591835, + -0.017270859330892563, + -0.012422707863152027, + 0.01713552325963974, + -0.026697810739278793, + 0.2446632832288742, + -0.020500628277659416, + -0.0012782106641680002, + -0.13429665565490723, + 0.07528743892908096, + -0.002225265372544527, + 0.06695574522018433, + 0.0017388156848028302, + -0.0629071593284607, + -0.05081196129322052, + 0.042025983333587646, + 0.029097404330968857, + 0.07048555463552475, + -0.11881273239850998, + 0.012633765116333961, + -0.06181430071592331, + 0.038810230791568756, + 0.05186169967055321, + 0.03248963877558708, + 0.07868267595767975, + 0.024977494031190872, + 0.023991582915186882, + 0.0023529180325567722, + 0.07197123020887375, + 0.02653665468096733, + 0.058702051639556885, + 0.015001803636550903, + 0.043739400804042816, + -0.07251746207475662, + 0.045659150928258896, + -0.02111324854195118, + 0.26666632294654846, + 0.1975221484899521, + -0.031074335798621178, + 0.029075143858790398, + 0.013020229525864124, + 0.015244663693010807, + 0.01387549377977848, + -0.025354426354169846, + 0.06151636317372322, + -0.034430794417858124, + 0.00752665288746357, + 0.1678706705570221, + -0.016560610383749008, + 0.0421285480260849, + -0.02527586743235588, + -0.02166694961488247, + -0.034658536314964294, + 0.036866605281829834, + -0.036233626306056976, + 0.02042747661471367, + 0.028099242597818375, + 0.020503878593444824, + 0.022789381444454193, + 0.08666791766881943, + -0.06426636874675751, + -0.043599683791399, + 0.1136128157377243, + 0.020200412720441818, + -0.003839759388938546, + -0.06010120362043381, + -0.02218424715101719, + 0.09008956700563431, + 0.008711264468729496, + -0.04874516651034355, + -0.011533043347299099, + -0.036206502467393875, + -0.006006627343595028, + -0.0350450798869133, + 0.005623341538012028, + 0.09562186151742935, + -0.03952183946967125, + -0.013931595720350742, + -0.020029470324516296, + 0.0022144403774291277, + -0.020198611542582512, + 0.012238736264407635, + 0.054415784776210785, + -0.024457741528749466, + -0.01174110360443592, + 0.031656913459300995, + 0.060322560369968414, + 0.01573050767183304, + 0.03361794352531433, + 0.022875478491187096, + 0.036340806633234024, + -0.02932620421051979, + 0.0224352665245533, + -0.013475337065756321, + -0.030774995684623718, + 0.013921404257416725, + -0.01229875348508358, + -0.07986237108707428, + -0.007543445099145174, + 0.05208213999867439, + -0.04440496116876602, + -0.029659371823072433, + -0.029070377349853516, + 0.07376870512962341, + -0.07208643853664398, + -0.05429431423544884, + -0.007887271232903004, + 0.011400371789932251, + 0.014227204024791718, + 0.01763899251818657, + -0.0426466204226017, + 0.0024213625583797693, + 0.02564665488898754, + 0.0020850151777267456, + 0.027386819943785667, + 0.12722602486610413, + -0.060991525650024414, + -0.009061425924301147, + 0.014208497479557991, + -0.006956137716770172, + 0.09096626192331314, + 0.0037735258229076862, + -0.8347064852714539, + -0.2857951521873474, + 0.0011818337952718139, + 0.0341162234544754, + -0.04230167716741562, + 0.05230262130498886, + 0.08486262708902359, + -0.34235459566116333, + -0.02393503487110138, + 0.02718495950102806, + 0.050966840237379074, + 0.024611525237560272, + -0.004936584271490574, + -0.036420952528715134, + -0.009803534485399723, + 0.05421328917145729, + 0.008357672952115536, + 0.020987343043088913, + -0.007292840629816055, + 0.018060531467199326, + 0.06739793717861176, + 0.06161382421851158, + 0.000842935056425631, + -0.007857701741158962, + 0.023870037868618965, + -0.009690430946648121, + -0.04231289029121399, + -0.22531479597091675, + 0.034284885972738266, + 0.07360551506280899, + 0.0421777106821537, + 0.000788167177233845, + -0.3953339457511902, + -0.042627450078725815, + -0.02774403616786003, + 0.02647743932902813, + -0.01561375055462122, + 0.04745408892631531, + 0.021774733439087868, + 0.006606150884181261, + 0.03879173845052719, + 0.06500626355409622, + 0.044954728335142136, + 0.01523532159626484, + 0.04741065576672554, + -0.13645507395267487, + 0.0038059696089476347, + -0.012993253767490387, + -0.004529603291302919, + 0.03268986567854881, + -0.025349941104650497, + -0.02268051542341709, + -0.0001516443444415927, + -0.010289257392287254, + -0.0010476588504388928, + -0.0690254345536232, + 0.04298266023397446, + -0.05470968782901764, + 0.04369102790951729, + -0.007372597698122263, + 0.027607066556811333, + 0.0009343988494947553, + -0.09573916345834732, + 0.04389296472072601, + -0.01522558368742466, + -0.03138086944818497, + 0.04511113464832306, + -0.0342172235250473, + -0.00033129166695289314, + -0.037289440631866455, + 0.055575959384441376, + 0.01849759928882122, + 0.03041103295981884, + -0.01965116336941719, + 0.07604960352182388, + -0.0399625338613987, + -0.008190250024199486, + -0.015386211685836315, + -0.04315667226910591, + 0.0023679479490965605, + 0.018971435725688934, + -0.005599244497716427, + -0.029607947915792465, + 0.07574024051427841, + -0.013816094025969505, + 0.04464992880821228, + 0.00032806122908368707, + 0.06071484833955765, + 0.04261377081274986, + 0.012208743952214718, + 0.0801805928349495, + 0.02875029854476452, + -0.0662921741604805, + 0.015754999592900276, + 0.05831082537770271, + 0.03810921683907509, + 0.05483977496623993, + -0.019509335979819298, + 0.0032034649048000574, + 0.011807492934167385, + -0.01916244812309742, + 0.022101666778326035, + -0.0366031751036644, + 0.10915965586900711, + 0.030322788283228874, + -0.028386037796735764, + -0.05443429946899414, + -0.02489445172250271, + 0.0892239362001419, + -0.05427740886807442, + -0.034238025546073914, + -0.04136161506175995, + -0.041148390620946884, + 0.06879492849111557, + -0.37424594163894653, + 0.028803903609514236, + 0.05349116027355194, + 0.0359492301940918, + -0.3629145622253418, + -0.17875684797763824, + -0.012246759608387947, + 0.2744927704334259, + -0.010421697050333023, + -0.19415415823459625, + 0.005668101832270622, + 0.018326066434383392, + 0.28319111466407776, + -0.008164885453879833, + -0.07401272654533386, + -0.04154321923851967, + 0.030028337612748146, + -0.008959534578025341, + -0.03160349279642105, + -0.0191870778799057, + 0.044875819236040115, + 0.052173007279634476, + 0.012135458178818226, + 0.008775291964411736, + 0.005302258301526308, + 0.009224606677889824, + -0.07574712485074997, + 0.06096252053976059, + 0.02645082212984562, + 0.05135556682944298, + 0.021985528990626335, + 0.0076704383827745914, + 0.02961125783622265, + -0.07608609646558762, + -0.17564956843852997, + 0.03679918497800827, + -0.2696506083011627, + 0.0627906322479248, + 0.031165480613708496, + 0.01799822598695755, + 0.02351829782128334, + 0.015595306642353535, + -0.25137314200401306, + -0.011266927234828472, + 0.04895596578717232, + 0.01718883402645588, + 0.0009224268142133951, + 0.021923478692770004, + 0.044791676104068756, + 0.079147569835186, + 0.02014082670211792, + -0.0003547854721546173, + -0.02535748854279518, + -0.029639363288879395, + -0.01965961419045925, + -0.37630724906921387, + 0.01674639992415905, + 0.01316642016172409, + -0.025120021775364876, + -0.12474260479211807, + 0.059980470687150955, + 0.036066047847270966, + -0.15973420441150665, + -0.010871605016291142, + 0.014708316884934902, + -0.2174367904663086, + 0.012985467910766602, + -0.03782057762145996, + -0.003427069401368499, + -0.011010636575520039, + 0.02433733455836773, + 0.08641276508569717, + -0.004630533047020435, + 0.019430357962846756, + -0.02088969387114048, + -0.06182911619544029, + 0.02577812969684601, + 0.015741532668471336, + 0.04723552614450455, + -0.003783567575737834, + 0.11646346747875214, + 0.01827184483408928, + -0.0999741181731224, + -0.0031216999050229788, + -0.002268272452056408, + -0.019456079229712486, + -0.003156653605401516, + 0.0067732855677604675, + 0.027299508452415466, + 0.06979037076234818, + 0.013329057022929192, + -0.016705401241779327, + 0.33774301409721375, + 0.007617524825036526, + 0.044453222304582596, + 0.0016282782889902592, + 0.0010982973035424948, + 0.04183036834001541, + 0.016857653856277466, + 0.006673034280538559, + -0.0187662523239851, + 0.0037163379602134228, + -0.04568779841065407, + -0.007807960733771324, + 0.016653010621666908, + 0.0033014933578670025, + 0.015063234604895115, + 0.012843966484069824, + -0.012042546644806862, + 0.016909126192331314, + 0.022089935839176178, + -0.002550398698076606, + 0.04166745766997337, + -0.0014742743223905563, + -0.010846617631614208, + -0.12333541363477707, + 0.0018612967105582356, + 0.04913188889622688, + -0.029431112110614777, + 0.01824735291302204, + 0.10425490140914917, + -0.08880072832107544, + 0.03029320202767849, + 0.018876856192946434, + 0.016104502603411674, + 0.00882721971720457, + 0.0029782119672745466, + 0.007922517135739326, + -0.02030068263411522, + -0.029835309833288193, + 0.006661414168775082, + -0.04313879832625389, + -0.001850730157457292, + -0.0035070034209638834, + -0.0070700813084840775, + 0.009637435898184776, + -0.016844747588038445, + -0.026075454428792, + 0.0030682040378451347, + 0.004208600614219904, + -0.005515689495950937, + -0.018976539373397827, + -0.019196776673197746, + -0.008948019705712795, + 0.016215825453400612, + 0.00296461652033031, + 0.14222395420074463, + -0.029066482558846474, + -0.011013337410986423, + -0.01267730537801981, + -0.004976287949830294, + -0.016607511788606644, + -0.0005681798211298883, + -0.012520174495875835, + -0.0015903630992397666, + -0.0013642794219776988, + -0.21956196427345276, + -0.0011431180173531175, + -0.0008808697457425296, + -0.022889399901032448, + 0.024718068540096283, + -0.054929111152887344, + -0.015585094690322876, + -0.018188318237662315, + -0.0008287815726362169, + -0.01957552134990692, + 0.10818513482809067, + -0.0034382494632154703, + -0.02667389065027237, + -0.01304248720407486, + -0.0034645304549485445, + -0.008519704453647137, + -0.015123830176889896, + -0.008219013921916485, + -0.009952309541404247, + -2.3375787350232713e-05, + -0.012512428686022758, + -0.001955948770046234, + -0.0029842876829206944, + -0.004291659686714411, + 0.006655955221503973, + 0.007771315053105354, + 0.014132227748632431, + -0.007390063256025314, + -0.024650415405631065, + -0.022503213956952095, + 0.0032607221510261297, + -0.008497492410242558, + 0.00860870536416769, + 0.002819088753312826, + -0.01841069757938385, + -0.010009711608290672, + -0.2912862300872803, + 0.017160022631287575, + 0.11349690705537796, + -0.027656083926558495, + -0.04482223838567734, + -0.019336597993969917, + 0.07413014769554138, + 0.014554106630384922, + 0.020965611562132835, + -0.028231356292963028, + -0.0582813061773777, + 0.05617539584636688, + -0.05042734369635582, + 0.025630727410316467, + -0.0956532284617424, + -0.14554104208946228, + -0.020851148292422295, + 0.006990485824644566, + 0.08457829803228378, + -0.11314752697944641, + 0.004020951222628355, + -0.03477870300412178, + 0.005594289395958185, + 0.011181964538991451, + 0.010988114401698112, + 0.019416088238358498, + 0.026451971381902695, + -0.00452260859310627, + 0.0004952011513523757, + 0.012377702631056309, + -0.0063480171374976635, + 0.0256175734102726, + -0.020753338932991028, + 0.03223377838730812, + -0.1147943064570427, + -0.009170151315629482, + 0.015267477370798588, + -0.0009072314132936299, + -0.1621374636888504, + 0.022807778790593147, + 0.007394107989966869, + 0.01378557924181223, + -0.10719677805900574, + -0.000919080339372158, + -0.006567052565515041, + -0.007409179583191872, + -0.007469762582331896, + -0.004784661345183849, + -0.03967805579304695, + 0.015857066959142685, + -0.02015744335949421, + 0.056037548929452896, + 0.03962035849690437, + 0.08429893851280212, + 0.022117067128419876, + -0.2675061821937561, + 0.016738418489694595, + 0.0037785861641168594, + 0.004771686624735594, + -0.134505033493042, + -0.010618447326123714, + -0.004784524440765381, + 0.014044507406651974, + -0.03105556219816208, + 0.05049083009362221, + 0.012162688188254833, + 0.005920265335589647, + 0.008554516360163689, + 0.0025892227422446012, + 0.023483717814087868, + -0.20711173117160797, + 0.03360452130436897, + -0.24758699536323547, + -0.05136318504810333, + -0.015016172081232071, + 0.06466241925954819, + 0.023470288142561913, + 0.023495715111494064, + 0.004300899337977171, + 0.02461574412882328, + 0.025745516642928123, + -0.026187308132648468, + 0.08441776037216187, + -0.06955462694168091, + -0.11116205900907516, + -0.2169608771800995, + -0.004244703333824873, + -0.024184226989746094, + -0.10068271309137344, + -0.021129190921783447, + -0.021129680797457695, + -0.0054467362351715565, + 0.17416934669017792, + 0.015367642976343632, + -0.01237915363162756, + 0.024573752656579018, + 0.004588739015161991, + 0.05616860091686249, + -0.0018992060795426369, + -0.12394066900014877, + -0.03691404312849045, + -0.15878455340862274, + 0.10572423785924911, + 0.014409378170967102, + -0.008566108532249928, + -0.20319701731204987, + -0.018277373164892197, + -0.21615462005138397, + -0.11269525438547134, + -0.2767113745212555, + -0.25617966055870056, + -0.0036413148045539856, + -0.008058675564825535, + -0.051732294261455536, + -0.013052727095782757, + 0.05229722708463669, + -0.03535814583301544, + 0.3111231327056885, + -0.044130608439445496, + -0.02232682704925537, + -0.0040402463637292385, + 0.013798556290566921, + -0.07689940929412842, + -0.028940049931406975, + -0.00565366679802537, + -0.028972560539841652, + -0.007728889584541321, + 0.013665011152625084, + -0.014678380452096462, + -0.06747694313526154, + -0.06480871140956879, + -0.00028885426581837237, + -0.01525174267590046, + 0.027096102014183998, + -0.05200905352830887, + 0.0066903820261359215, + 0.0023834225721657276, + -0.002379713812842965, + -0.0208051148802042, + 0.335977703332901, + 0.03895771875977516, + -0.04814215749502182, + -0.037339694797992706, + -0.004409746266901493, + 0.07042848318815231, + -0.08318590372800827, + -0.04138712212443352, + 0.06309781968593597, + 0.007484383415430784, + 0.09696535021066666, + 0.024134323000907898, + -0.009859816171228886, + -0.06243982911109924, + 0.04630015045404434, + -0.06593744456768036, + 0.009306293912231922, + 0.5033899545669556, + 0.007804783061146736, + 0.024170484393835068, + -0.036085959523916245, + 0.016438491642475128, + 0.01678072288632393, + -0.006299734115600586, + -0.027441656216979027, + -0.014344800263643265, + 0.022293711081147194, + 0.011197407729923725, + -0.0026971842162311077, + 0.2685070335865021, + 0.01403988990932703, + -0.005100077483803034, + -0.026031343266367912, + -0.005419034510850906, + -0.014735087752342224, + -0.0283498577773571, + 0.002656748052686453, + -0.07137783616781235, + 0.02235356532037258, + -0.02970476634800434, + 0.20672672986984253, + 0.017398398369550705, + 0.02438206970691681, + 0.025746773928403854, + -0.03279582038521767, + 0.043908532708883286, + -0.003417646512389183, + 0.020200302824378014, + 0.007243862375617027, + -0.004560714587569237, + -0.01142876222729683, + -0.028091270476579666, + -0.2949703335762024, + 0.0729827880859375, + 0.004566277377307415, + 0.16689160466194153, + 0.034872010350227356, + -0.09590360522270203, + -0.13309867680072784, + 0.06429398059844971, + 0.04174232855439186, + -0.022723963484168053, + -0.04695400968194008, + 0.013115685433149338, + 0.013574879616498947, + 0.04794493317604065, + -0.015077140182256699, + 0.09493618458509445, + 0.008845972828567028, + 0.020302923396229744, + 0.02037016488611698, + 0.009083293378353119, + 0.0747746080160141, + -0.008078188635408878, + 0.024796344339847565, + -0.015212535858154297, + -0.005867444910109043, + 0.08309170603752136, + 0.03676094114780426, + 0.07232356816530228, + -0.3577176630496979, + 0.0013658110983669758, + -0.0009247250854969025, + 0.02284996211528778, + 0.012630275450646877, + 0.013745593838393688, + 0.003447894938290119, + 0.03563565015792847, + -0.031025355681777, + -0.07258180528879166, + -0.13482442498207092, + -0.029425248503684998, + -0.014927731826901436, + 0.045984312891960144, + -0.0176406130194664, + -0.22678181529045105, + -0.025248311460018158, + -0.11617762595415115, + -0.056157518178224564, + 0.009453062899410725, + -0.34616726636886597, + 0.05691010504961014, + -0.32302799820899963, + -0.026544231921434402, + -0.007374088745564222, + -0.07682909071445465, + -0.021214107051491737, + -0.07102422416210175, + 0.02693488635122776, + 0.014817211776971817, + 0.015572831965982914, + 0.04313618317246437, + -0.1277216374874115, + 0.02174532599747181, + -0.0226149819791317, + -0.00010956164624076337, + 0.023728065192699432, + 0.008212783373892307, + 0.010561724193394184, + -0.011036543175578117, + -0.022485855966806412, + 0.008243439719080925, + -0.03383245691657066, + -0.5630682110786438, + 0.0015974265988916159, + -0.28416821360588074, + 0.04123701527714729, + -0.0042976438999176025, + 0.03786511346697807, + 0.01862393692135811, + -0.04082413762807846, + -0.05792848393321037, + 0.0068894242867827415, + 0.0024085959885269403, + 0.001471342402510345, + 0.030681759119033813, + -0.026314062997698784, + 0.0555737242102623, + 0.03169534355401993, + 0.0031395808327943087, + 0.018701769411563873, + -0.5604594945907593, + 0.01526441890746355, + -0.00621993700042367, + 0.0009401043644174933, + 0.01587403193116188, + 0.030135583132505417, + -0.007350685074925423, + 0.006527469493448734, + 0.016000108793377876, + -0.042957425117492676, + 0.018247080966830254, + 0.0025622656103223562, + -0.03169511258602142, + 0.09235119074583054, + -0.013365034945309162, + 0.01607452519237995, + 0.017734844237565994, + 0.05609896034002304, + 0.04819876700639725, + -0.0871855691075325, + 0.05157865956425667, + 0.009171447716653347, + 0.022200705483555794, + -0.005507844965904951, + -0.024452703073620796, + 0.010224574245512486, + -0.006914906669408083, + 0.004650818649679422, + 0.02167516015470028, + 0.10456826537847519, + -0.07652094960212708, + -6.050072988728061e-05, + 0.012855490669608116, + 0.022669879719614983, + 0.022655120119452477, + 0.033012885600328445, + 0.025709744542837143, + 0.00481270719319582, + 0.005920717027038336, + -0.08545156568288803, + -0.004363589454442263, + -0.01531639602035284, + 0.030760569497942924, + 0.02796284481883049, + -0.03690989315509796, + 0.044959694147109985, + -0.14276015758514404, + -0.0002254673163406551, + -0.15694372355937958, + 0.012381293810904026, + -0.021977441385388374, + 0.005496624857187271, + -0.035593707114458084, + -0.0950438603758812, + 0.03825876861810684, + 0.05915532633662224, + -0.023323312401771545, + 0.017213119193911552, + -0.03807183355093002, + 0.02619507722556591, + 0.02741156332194805, + 0.005847832188010216, + 0.0020307491067796946, + 0.025714349001646042, + -0.04780200496315956, + 0.010206928476691246, + -0.01345440000295639, + 0.029133174568414688, + -0.0014764482621103525, + 0.004046705551445484, + -0.007725241594016552, + 0.013041527941823006, + 0.0018969239899888635, + 0.002417983952909708, + -0.010975837707519531, + 0.0015862436266615987, + 0.00597577728331089, + 0.002882696921005845, + 0.02855525352060795, + -0.005954153370112181, + 0.04090835899114609, + -0.39500924944877625, + 0.03586621209979057, + -0.5250031352043152, + -0.05697731301188469, + -0.09568691998720169, + -0.07179264724254608, + 0.04683076590299606, + 0.009320023469626904, + -0.11629963666200638, + -0.0016945215174928308, + 0.01624997705221176, + -0.0063682254403829575, + 0.15033549070358276, + -0.5171176791191101, + -0.01525783073157072, + 0.016417231410741806, + -0.00303818890824914, + 0.2500321865081787, + 0.022074062377214432, + 0.01191191840916872, + 0.012274803593754768, + 0.016534989699721336, + -0.028437916189432144, + 0.04241323843598366, + -0.01824999786913395, + -0.34815871715545654, + 0.04734490439295769, + -0.06419701874256134, + -0.022288290783762932, + -0.0004865761147812009, + 0.05369419604539871, + -0.058212973177433014, + -0.2196469008922577, + 0.010950890369713306, + 0.029042819514870644, + -0.07349151372909546, + -0.0422789566218853, + 0.062069639563560486, + 0.05589267984032631, + 0.014877256006002426, + 0.04236084595322609, + 0.03975239768624306, + 0.16930873692035675, + 0.03981085494160652, + 0.11499395221471786, + 0.0271450225263834, + 0.013969083316624165, + -0.0002660648606251925, + 0.010936664417386055, + -0.18389767408370972, + -0.10237602889537811, + 0.03041323646903038, + -0.013864071108400822, + -0.015729930251836777, + 0.037400804460048676, + -0.009598327800631523, + -0.09533312171697617, + -0.014712700620293617, + 0.08537333458662033, + -0.007200485561043024, + -0.31139102578163147, + -0.06366845220327377, + 0.02039063163101673, + -0.023356139659881592, + -0.0029549277387559414, + -0.12494662404060364, + 0.011755092069506645, + -0.26468148827552795, + -0.11541861295700073, + 0.010529865510761738, + -0.05965733155608177, + -0.05945499241352081, + -0.08796169608831406, + -0.014683439396321774, + 0.008732054382562637, + 0.010073489509522915, + 0.09553763270378113, + 0.034884922206401825, + 0.018675342202186584, + -0.009549405425786972, + -0.0007051719003356993, + -0.16936513781547546, + -0.0030460187699645758, + -0.022060535848140717, + -0.06689190864562988, + 0.013926704414188862, + 0.012043816037476063, + -0.0587068572640419, + -0.03814113140106201, + 0.06235629320144653, + 0.013228330761194229, + 0.04154474660754204, + -0.08039120584726334, + 0.028436705470085144, + -0.042226389050483704, + -0.019135186448693275, + 0.03747033327817917, + -0.14261123538017273, + 0.02827540971338749, + 0.0455685593187809, + -0.031124960631132126, + -0.007588588632643223, + 0.0034326373133808374, + -0.07682976871728897, + 0.24654042720794678, + -0.014518304727971554, + -0.07052458822727203, + -0.08241941034793854, + -0.04116151109337807, + -0.048463717103004456, + -0.038745298981666565, + 0.036902472376823425, + 0.0442035011947155, + 0.05572585016489029, + -0.014312628656625748, + 0.010794793255627155, + -0.3440641760826111, + -0.5161325335502625, + 0.0005156552069820464, + -0.010257269255816936, + -0.02412656880915165, + -0.023385023698210716, + 0.05533458665013313, + -0.012186119332909584, + -0.029286568984389305, + 0.04116401448845863, + -0.044610101729631424, + -0.019175484776496887, + 0.06835268437862396, + 0.06366674602031708, + 0.0373748242855072, + 0.03804386034607887, + 0.05369521677494049, + -0.04451881721615791, + 0.0018838117830455303, + 0.34775662422180176, + 0.010958605445921421, + -0.047990139573812485, + 0.04386777803301811, + -0.10427688807249069, + 0.04417382925748825, + 4.402965714689344e-05, + 0.01935163326561451, + -0.06753949075937271, + 0.02735923044383526, + 0.01465953141450882, + 0.06198301538825035, + -0.015980403870344162, + -0.2108263075351715, + 0.008177559822797775, + 0.006046924740076065, + 0.002665479900315404, + 0.20868580043315887, + -0.013740362599492073, + 0.008203004486858845, + -0.005066391546279192, + 0.026405498385429382, + 0.01383009273558855, + 0.012581533752381802, + 0.009014940820634365, + 0.022820021957159042, + -0.008534795604646206, + 0.2603924572467804, + 0.02297227643430233, + -0.000749691273085773, + 0.044753506779670715, + 0.018596511334180832, + 0.006852792575955391, + -0.008686172775924206, + -0.10452616959810257, + 0.017021872103214264, + 0.003722329391166568, + -0.025453045964241028, + -0.011473417282104492, + -0.017907623201608658, + 0.01400628499686718, + -0.1670989990234375, + 0.004298652987927198, + -0.0022204748820513487, + 0.16521315276622772, + -0.008831127546727657, + 0.026490870863199234, + 0.006190746556967497, + -0.0177209060639143, + 0.08967147767543793, + 0.0033069502096623182, + -0.005021366756409407, + 0.0004906906979158521, + 0.0169216375797987, + -0.06124846637248993, + -0.005200678016990423, + 0.08404737710952759, + -0.010559299029409885, + -0.006309974938631058, + 0.023113396018743515, + -0.010227260179817677, + 0.001256447983905673, + 0.019783375784754753, + -0.006308461539447308, + -0.04529590904712677, + -0.00908862054347992, + -0.043217338621616364, + -0.32200074195861816, + 0.02592635713517666, + 0.030795685946941376, + -0.001814531977288425, + 0.0092842485755682, + 0.07088880985975266, + -0.0867588147521019, + 0.024099843576550484, + -0.0034031609538942575, + 0.007234686985611916, + -0.02505563199520111, + 0.0030480287969112396, + -0.019158190116286278, + 0.26473408937454224, + -0.011918547563254833, + -0.023240016773343086, + -0.06084466353058815, + -0.021916134282946587, + -0.010251260362565517, + -0.0009625791572034359, + 0.082605741918087, + -0.013018425554037094, + 0.007627277635037899, + -0.0010813736589625478, + 0.007952406071126461, + 0.06551267951726913, + -0.026020025834441185, + 0.050048135221004486, + -0.010610008612275124, + -0.02429312653839588, + -0.025263017043471336, + -0.04611891135573387, + 0.04451768472790718, + -0.08045025914907455, + -0.048037610948085785, + 0.008019295521080494, + 0.0160224549472332, + 0.002078550634905696, + -0.0202508345246315, + -0.5446130633354187, + 0.012585492804646492, + -0.0331973135471344, + 0.08371605724096298, + -0.00590998912230134, + -0.013058983720839024, + 0.027742384001612663, + 0.1042199358344078, + -0.3072803318500519, + 0.06284149736166, + -0.28551968932151794, + 0.026768438518047333, + 0.022245990112423897, + 0.018242113292217255, + -0.035077981650829315, + 0.03546127676963806, + 0.10165776312351227, + -0.025475669652223587, + -0.014933750964701176, + 0.040547240525484085, + -0.033055808395147324, + 0.011755919083952904, + -0.014459444209933281, + -0.03455093130469322, + 0.020743343979120255, + 0.02720930427312851, + -0.287664532661438, + 0.008260028436779976, + -0.009877690114080906, + 0.16657423973083496, + -0.010943812318146229, + -0.012381386943161488, + 0.030678801238536835, + 0.1559792459011078, + 0.038967035710811615, + -0.023399239405989647, + 0.015019542537629604, + -0.014201333746314049, + -0.014202176593244076, + -0.006699408870190382, + -0.13175444304943085, + 0.004643211141228676, + 0.012747463770210743, + -0.04086190089583397, + 0.06581410765647888, + -0.12192045897245407, + -0.03126347437500954, + 0.011175516061484814, + -0.00914736744016409, + -0.02883930690586567, + -0.11305265873670578, + -0.04405384883284569, + -0.009120048955082893, + -0.008926079608500004, + -0.03169447183609009, + 0.05464877560734749, + 0.25674498081207275, + 0.08497058600187302, + -0.023222925141453743, + 0.35592252016067505, + -0.006929511670023203, + 0.025255810469388962, + -0.05150032415986061, + 0.039239466190338135, + -0.07082924991846085, + -0.017321549355983734, + 0.17293211817741394, + -0.02155853807926178, + -0.014333213679492474, + 0.0031305316369980574, + -0.013490653596818447, + -0.1376512199640274, + -0.021713266149163246, + -0.029826253652572632, + -0.0011473714839667082, + -0.012434332631528378, + -0.04860873892903328, + 0.013857590034604073, + 0.0703854188323021, + 0.034528713673353195, + -0.014423011802136898, + 0.0882454589009285, + -0.091700978577137, + 0.038885727524757385, + 0.012043441645801067, + -0.03183690831065178, + -0.014495689421892166, + -0.019726552069187164, + -0.010094117373228073, + -0.004218627233058214, + -0.04413086175918579, + -0.1344134360551834, + -0.0004976870259270072, + -0.0008357573533430696, + 0.04518067091703415, + 0.046797975897789, + 0.24766182899475098, + 0.01065139751881361, + -0.0034267394803464413, + -0.016103556379675865, + -0.05139121413230896, + 0.012563390657305717, + -0.03310413286089897, + -0.030157553032040596, + 0.046670909970998764, + 0.012565785087645054, + -0.040275491774082184, + 0.023816417902708054, + -0.38536572456359863, + 0.04508889466524124, + 0.13637560606002808, + -0.010654824785888195, + 0.0459851399064064, + -0.0046302699483931065, + -0.020852191373705864, + 0.10662271827459335, + 0.06486576050519943, + 0.05727925896644592, + 0.09816201776266098, + 0.04878557100892067, + -0.16256237030029297, + 0.014547038823366165, + 0.018567964434623718, + -0.07284612208604813, + 0.017150163650512695, + 0.0246741883456707, + -0.38470372557640076, + -0.07465949654579163, + 0.03010236658155918, + -0.004397575277835131, + -0.06618984788656235, + -0.02908281609416008, + 0.060166433453559875, + -0.0020949048921465874, + 0.007689109072089195, + -0.0047390698455274105, + -0.014199030585587025, + -0.01794746331870556, + -0.02528063952922821, + 0.002218312583863735, + 0.10169881582260132, + 0.010602130554616451, + -0.06605861335992813, + -0.0008762837387621403, + -0.035027723759412766, + -0.011684391647577286, + 0.02247578091919422, + 0.17245104908943176, + 0.22525252401828766, + -0.010771296918392181, + 0.05595310404896736, + 0.06338834017515182, + -0.0038216698449105024, + -0.0032836494501680136, + 0.005779017228633165, + -0.18020786345005035, + -0.05066698044538498, + -0.0035458216443657875, + -0.10578767210245132, + -0.041712939739227295, + 0.2104150652885437, + -0.03753345459699631, + 0.013989892788231373, + 0.01988149993121624, + 0.05108603090047836, + 0.04496738687157631, + -0.3034508526325226, + 0.0226743221282959, + -0.0431472510099411, + -0.025635428726673126, + -0.18961989879608154, + -0.17218825221061707, + 0.03576141223311424, + 0.060613714158535004, + -0.011970550753176212, + -0.21435107290744781, + 0.01422552578151226, + 0.02974064089357853, + -0.061079952865839005, + 0.031064646318554878, + 0.009629320353269577, + -0.13762925565242767, + 0.01928475871682167, + 0.007310172542929649, + 0.06103459745645523, + -0.16216528415679932, + 0.03330384939908981, + 0.09578404575586319, + -0.0037327276077121496, + 0.029233848676085472, + -0.0015759399393573403, + 0.005511409603059292, + -0.4195749759674072, + 0.024169376119971275, + 0.13220365345478058, + 0.007961929775774479, + 0.008045470342040062, + 0.01919495314359665, + -0.023188553750514984, + 0.07084394991397858, + -0.24922333657741547, + 0.02011212892830372, + -0.18514998257160187, + 0.03114209696650505, + 0.09826567023992538, + 0.00592303741723299, + -0.010020115412771702, + 0.027117054909467697, + -0.214133620262146, + -0.01214816514402628, + 0.06564164906740189, + 0.02513044886291027, + 0.02132420241832733, + -0.02127540111541748, + -0.041606876999139786, + 0.04196378216147423, + -0.02060609683394432, + 0.01730814389884472, + -0.17418994009494781, + 0.03462710976600647, + -0.017470642924308777, + -0.3992193639278412, + 0.02652592957019806, + 0.025042008608579636, + 0.026447610929608345, + -0.19199316203594208, + 3.27593952533789e-05, + 0.002988220192492008, + -0.21171888709068298, + 0.03300239518284798, + 0.015727035701274872, + -0.008947308175265789, + 0.03924538940191269, + -0.08990193158388138, + 0.023726975545287132, + 0.03463870286941528, + -0.05018220469355583, + 0.13170146942138672, + 0.054000236093997955, + 0.01158218178898096, + 0.062349993735551834, + -0.014724616892635822, + 0.039657603949308395, + 0.04436490684747696, + 0.014076294377446175, + 0.07666806876659393, + 0.09630247205495834, + -0.04152659326791763, + -0.1860806941986084, + -0.07671733945608139, + 0.031573690474033356, + -0.44617798924446106, + -0.004897239152342081, + -0.03991628438234329, + 0.01880800537765026, + -0.04769768565893173, + 0.02198435738682747, + 0.01341161783784628, + -0.12239313870668411, + 0.019765935838222504, + 0.005221452098339796, + -0.025201082229614258, + 0.005132562946528196, + 0.08668412268161774, + 0.0035341952461749315, + 0.008583099581301212, + 0.032979920506477356, + 0.03324040770530701, + 0.04411708936095238, + -0.008390798233449459, + 0.040486790239810944, + -0.059673551470041275, + 0.02003314346075058, + -0.0990666076540947, + 0.03971675783395767, + 0.012021057307720184, + 0.0017271327087655663, + 0.01818535290658474, + 0.0025106174871325493, + 0.043714240193367004, + 0.019146842882037163, + -0.0041794623248279095, + 0.033447377383708954, + 0.06863203644752502, + -0.004350902978330851, + 0.0113364327698946, + -0.05825724080204964, + -0.04649435728788376, + -0.10618306696414948, + 0.02653644233942032, + 0.012514552101492882, + 0.019399365410208702, + -0.0022177041973918676, + 0.017741208896040916, + 0.04115311801433563, + 0.05122101679444313, + 0.055051617324352264, + 0.01687677949666977, + -0.03698579967021942, + 0.10053858160972595, + -0.007528421934694052, + 0.003968802746385336, + 0.02458524890244007, + -0.02144794538617134, + 0.026791265234351158, + -0.016701897606253624, + 0.014119372703135014, + -0.03460531681776047, + -0.02320348098874092, + 0.056146953254938126, + 0.028700685128569603, + -0.14820916950702667, + -0.016996873542666435, + 0.025667931884527206, + 0.08408629894256592, + 0.00034475952270440757, + 0.007573155220597982, + 0.06784884631633759, + 0.025982951745390892, + -0.08363039791584015, + -0.015748541802167892, + -0.0029514851048588753, + -0.01523523684591055, + 0.10500328987836838, + 0.3070858418941498, + -0.024624783545732498, + 0.0058471946977078915, + -0.039751242846250534, + 0.0012745993444696069, + -0.0796508714556694, + 0.024727927520871162, + 0.056764136999845505, + -0.013338261283934116, + -0.04794292524456978, + -0.02609768509864807, + -0.010784422047436237, + -0.048712026327848434, + 0.020345501601696014, + 0.0021618579048663378, + -0.0021724768448621035, + 0.03056410700082779, + -0.01633712649345398, + -0.47168225049972534, + -0.014639903791248798, + -0.012550815008580685, + 0.03358187526464462, + 0.07889427989721298, + -0.03615899011492729, + -0.002809660043567419, + -0.006953644100576639, + 0.02024337276816368, + -0.0738825723528862, + -0.006984011270105839, + -0.04472561925649643, + -0.027498915791511536, + 0.07207506150007248, + -0.09166522324085236, + -0.008861960843205452, + 0.05264359340071678, + 0.01889069564640522, + -0.1380404680967331, + -0.010141258127987385, + 0.015403619967401028, + -0.16416165232658386, + -0.03529815003275871, + 0.042106859385967255, + 0.11173021793365479, + -0.3143587112426758, + 0.011045016348361969, + 0.0012351945042610168, + 0.03840603306889534, + 0.0685538575053215, + -0.000746160454582423, + -0.028142500668764114, + 0.027154160663485527, + 0.005731801502406597, + 0.04433267563581467, + -0.8158469796180725, + 0.02226361259818077, + -0.07650655508041382, + 0.026958195492625237, + -0.005810025613754988, + -0.020102059468626976, + -0.0019310436910018325, + 0.07697021961212158, + -0.057701658457517624, + 0.05954534560441971, + 0.0027106746565550566, + -0.06311310827732086, + 0.011713752523064613, + -0.0034454476553946733, + -0.0006881420267745852, + 0.08937360346317291, + -0.0008253820124082267, + -0.031066063791513443, + -0.14708301424980164, + -0.04438449814915657, + 0.004772413522005081, + 0.05992274731397629, + 0.07473544776439667, + -0.1784757375717163, + -0.19057415425777435, + -0.014637955464422703, + -0.24898527562618256, + 0.13606221973896027, + -0.018039124086499214, + -0.047193415462970734, + -0.06526428461074829, + 0.04075757786631584, + 0.049901530146598816, + -0.008585861884057522, + 0.01616351678967476, + -3.091737016802654e-05, + 0.024283329024910927, + 0.008861682377755642, + -0.0005823548417538404, + 0.0997646301984787, + 0.051001910120248795, + 0.009473294951021671, + -0.0032046104315668344, + 0.018362928181886673, + 0.008627718314528465, + -0.4148157835006714, + -0.016077928245067596, + 0.0745391696691513, + 0.00724065862596035, + 0.08948155492544174, + 0.11626332253217697, + -0.052439428865909576, + 0.005599102005362511, + 0.002622961765155196, + 0.07586965709924698, + 0.03274847939610481, + -0.02099076844751835, + -0.04666733741760254, + -0.0013019372709095478, + 0.04945925995707512, + 0.11393380910158157, + 0.006346395239233971, + 0.04721064493060112, + 0.010331138968467712, + 0.08918803185224533, + 0.04288423806428909, + -0.09234773367643356, + 0.020141584798693657, + -3.256054696976207e-05, + -0.02799108810722828, + 0.018966441974043846, + -0.4136410355567932, + -0.07217283546924591, + 0.01840362884104252, + -0.055327851325273514, + 0.003275467548519373, + -0.017174070701003075, + -0.032178670167922974, + 0.09021560847759247, + -0.524413526058197, + 0.01994725503027439, + 0.10380692034959793, + -0.01043684035539627, + -0.00011200909648323432, + 0.01331041194498539, + 0.020127851516008377, + -0.025159789249300957, + 0.05252581834793091, + 0.04759140685200691, + 0.0032084162812680006, + -0.03579062595963478, + 0.054719552397727966, + -0.04674411937594414, + 0.028389262035489082, + 0.001127603929489851, + -0.0006243048119358718, + -0.00550495833158493, + -0.022523507475852966, + -0.024282312020659447, + 0.009519628249108791, + -0.39908328652381897, + -0.009265545755624771, + -0.00037090369733050466, + 0.06425131112337112, + -0.05998316407203674, + -0.015221518464386463, + -0.004825026262551546, + 0.11847284436225891, + -0.011302731931209564, + -0.006884834263473749, + -0.04678218811750412, + -0.012078279629349709, + 0.021638741716742516, + -0.016819776967167854, + -0.009127719327807426, + -0.002491263672709465, + 0.0016752213705331087, + -0.016600262373685837, + 0.011772023513913155, + -0.013447183184325695, + -0.020662957802414894, + -0.011593316681683064, + 0.008270744234323502, + -0.0026990456972271204, + -0.004406482446938753, + -0.023110052570700645, + -0.00208942755125463, + -0.1711198389530182, + 0.012432538904249668, + -0.0045453268103301525, + 0.024807902052998543, + -0.0035043740645051003, + -0.004001997876912355, + -0.013488625176250935, + -0.02020987868309021, + -0.01216109935194254, + -0.004432092886418104, + 0.09323672950267792, + -0.015641510486602783, + -0.019307948648929596, + 0.01117538008838892, + -0.01422040443867445, + 0.01705607771873474, + -0.0029596879612654448, + -0.0021530911326408386, + -0.006551788654178381, + 0.00429268553853035, + -0.1620807945728302, + -0.014128226786851883, + -0.005428737495094538, + -0.006771362852305174, + 0.005730633158236742, + 0.0007243106956593692, + 0.0024031582288444042, + -0.00199915561825037, + 0.006133859045803547, + -0.013380909338593483, + 0.00733462069183588, + -0.001863821060396731, + -0.0020169683266431093, + -0.014070986770093441, + -0.006501683499664068, + -0.029421553015708923, + 0.0009377509704791009, + -0.01718256250023842, + -0.05819401144981384, + -0.018859732896089554, + 0.0010356366401538253, + 0.006394123658537865, + -0.021985618397593498, + -0.01204769592732191, + -0.002014884725213051, + -0.019398409873247147, + -0.013122898526489735, + -0.017277296632528305, + -0.002270353492349386, + -0.05294327810406685, + -0.020317314192652702, + -0.018196573480963707, + -0.010375416837632656, + -0.019704729318618774, + -0.016109557822346687, + -0.0167380403727293, + -0.0285252146422863, + -0.02665277197957039, + -0.03554505482316017, + -0.00741522666066885, + -0.013580105267465115, + -0.026335405185818672, + -0.011694515123963356, + -0.004639182705432177, + -0.03996071219444275, + -0.022463932633399963, + -0.007204636000096798, + -0.021065134555101395, + -0.014410646632313728, + 0.0035447971895337105, + -0.0013098351191729307, + -0.024171002209186554, + 0.00047751085367053747, + -0.01870289072394371, + -0.06016797944903374, + -0.025703946128487587, + -0.009730588644742966, + -0.021792838349938393, + -0.024519823491573334, + -0.01843440905213356, + -0.0016325484029948711, + -0.008116388693451881, + -0.017774557694792747, + -0.04375867918133736, + -0.03893980756402016, + -0.018188582733273506, + -0.007122726645320654, + -0.028115490451455116, + -0.01821342669427395, + -0.01011319737881422, + -0.02616124413907528, + -0.013797983527183533, + -0.03202736750245094, + -0.030110370367765427, + -0.01883666031062603, + -0.01185502391308546, + -0.006012012716382742, + -0.017311619594693184, + -0.022577986121177673, + -0.02101938985288143, + 0.0025952248834073544, + -0.005058783106505871, + -0.004162575118243694, + -0.01559755764901638, + -0.017923563718795776, + -0.04231095686554909, + -0.017630560323596, + -0.011938830837607384, + -0.01587115228176117, + 0.004972478374838829, + -0.016601158306002617, + 0.15419845283031464, + 0.0009241115767508745, + 0.051028184592723846, + 0.008128340356051922, + -0.019917558878660202, + -0.0010339801665395498, + 0.022349294275045395, + -0.0072520882822573185, + 0.0017750378465279937, + -0.10526080429553986, + 0.03420695662498474, + 0.019183926284313202, + -0.0006544998032040894, + -0.0032203509472310543, + -0.01216941885650158, + -0.03561796247959137, + 0.024905826896429062, + -0.026948239654302597, + -0.01913355104625225, + -0.014459407888352871, + 0.006972283590584993, + -0.033184293657541275, + 0.04884861409664154, + -0.002296984428539872, + -0.19194477796554565, + 0.00392142403870821, + 0.009490449912846088, + -0.02687196619808674, + -0.06327224522829056, + -0.03684951737523079, + -0.0002613202668726444, + -0.012086644768714905, + 0.03630973398685455, + 0.007296048104763031, + 0.011186012998223305, + 0.0074085514061152935, + -0.020394617691636086, + -0.010585476644337177, + -0.030289918184280396, + 0.0773506686091423, + 0.008841303177177906, + 0.019423579797148705, + 0.001184571417979896, + 0.005553434602916241, + 0.015373414382338524, + -0.0027953842654824257, + 0.013204757124185562, + 0.029097743332386017, + 0.012627501040697098, + 0.02102004364132881, + -0.09469914436340332, + -0.023324014618992805, + 0.029243655502796173, + 0.002979277865961194, + -0.004492263309657574, + 0.20549021661281586, + -0.3244459927082062, + 0.025892559438943863, + 0.009620796889066696, + -0.05520407855510712, + -0.02271144650876522, + 0.008378816768527031, + -0.0671214610338211, + -0.016056722030043602, + -0.02355658821761608, + 0.0005429868469946086, + -0.007960098795592785, + 0.02513299137353897, + -0.13005328178405762, + -0.0025323680602014065, + -0.02197088487446308, + -0.02404806576669216, + 0.08261960744857788, + 0.17078880965709686, + 0.02880753017961979, + -0.03642067685723305, + 0.021994341164827347, + -0.012368184514343739, + -0.10681373625993729, + 0.16371481120586395, + 0.17881983518600464, + -0.10202010720968246, + -0.08641688525676727, + -0.1259487271308899, + 0.06907707452774048, + 0.023792706429958344, + -0.02534419298171997, + 0.016984017565846443, + -0.06743635982275009, + 0.08445960283279419, + -0.08037827908992767, + -0.11935994029045105, + -0.31716489791870117, + -0.01860150322318077, + 0.060669515281915665, + -0.06137414649128914, + 0.09878886491060257, + 0.01794014871120453, + 0.12382296472787857, + -0.016424886882305145, + 0.09045679122209549, + -0.02998783066868782, + -0.00972777884453535, + -0.024124544113874435, + 0.09879253059625626, + 0.05500243604183197, + -0.06635259836912155, + 0.11268552392721176, + 0.011751363053917885, + -0.04690232127904892, + -0.025168607011437416, + 0.088335320353508, + -0.1140628531575203, + 0.04129032790660858, + -0.04258979484438896, + -0.0903872698545456, + 0.008473021909594536, + -0.026690304279327393, + -0.051559556275606155, + -0.05481572076678276, + -0.05251916125416756, + -0.0018165932269766927, + 0.09836867451667786, + 0.0054859439842402935, + 0.06432581692934036, + 0.10621821135282516, + -0.019325286149978638, + -0.028727786615490913, + 0.014013150706887245, + -0.008022608235478401, + -0.006281842477619648, + -0.0297000203281641, + 0.01525485422462225, + -0.4346403479576111, + 0.07787995040416718, + -0.25380268692970276, + 0.05261845141649246, + 0.010875157080590725, + 0.0014149334747344255, + 0.05021188035607338, + -0.24382442235946655, + 0.0807114690542221, + 0.022907381877303123, + 0.006440790370106697, + -0.017028095200657845, + 0.001552293193526566, + 0.05961666256189346, + -0.14113056659698486, + 0.03398876264691353, + -0.005411976482719183, + -0.014025667682290077, + -0.5433799624443054, + 0.019015472382307053, + 0.04091138765215874, + 0.05059061944484711, + 0.0274446289986372, + -0.010288042947649956, + -0.001335533568635583, + -0.013533512130379677, + 0.018798377364873886, + -0.04099345579743385, + 0.0031264263670891523, + -0.21071769297122955, + -0.014384736306965351, + -0.1045387014746666, + -0.014340974390506744, + 0.001986369490623474, + -0.04118456318974495, + -0.10952988266944885, + 0.049147430807352066, + -0.08382093161344528, + -0.1741400957107544, + -0.0885215476155281, + -0.10934099555015564, + 0.05553343519568443, + 0.02434251271188259, + 0.006634524557739496, + -0.0017163373995572329, + 0.0185443926602602, + 0.06250902265310287, + -0.17145656049251556, + -0.07543934881687164, + 0.026583310216665268, + 0.01634727604687214, + 0.003603539662435651, + -0.2817271649837494, + 0.03882112354040146, + 0.011341865174472332, + 0.00826666783541441, + 0.050427842885255814, + -0.22358834743499756, + 0.06419781595468521, + 0.03245265409350395, + -0.04503164440393448, + -0.023194484412670135, + -0.027968740090727806, + 0.08563586324453354, + 0.07954753190279007, + -0.08513130992650986, + 0.02850884199142456, + 0.008976672776043415, + 0.07886530458927155, + 0.0022273347713053226, + -0.09540755301713943, + 0.032016951590776443, + -0.05196075513958931, + 0.10555616766214371, + 0.07629868388175964, + 0.039732079952955246, + -0.0029798501636832952, + 0.014692343771457672, + 0.09200941026210785, + -0.04299614951014519, + -0.023488566279411316, + -0.01851060427725315, + 0.09257487207651138, + 0.055612049996852875, + 0.06423109769821167, + -0.28587806224823, + -0.09950444847345352, + 0.10397437959909439, + 0.025166453793644905, + -0.03235514089465141, + -0.033381711691617966, + 0.1513858139514923, + 0.06468874961137772, + 0.01928441785275936, + 0.0032701045274734497, + -0.0579083226621151, + -0.022929169237613678, + 0.012971373274922371, + -0.018524186685681343, + -0.06484643369913101, + 0.012233717367053032, + 0.06590451300144196, + -0.04558677598834038, + 0.05253027006983757, + 0.048656731843948364, + -0.2288871705532074, + 0.037114787846803665, + -0.20519588887691498, + 0.0058607361279428005, + -0.002009372925385833, + -0.006671734619885683, + -0.07107856124639511, + -0.07407436519861221, + 0.03941629081964493, + 0.0447598397731781, + 0.03509354963898659, + -0.061107732355594635, + -0.09305761009454727, + -0.012180411256849766, + 0.04902016744017601, + 0.07974442094564438, + -0.016854410991072655, + 0.005089411046355963, + -0.08127597719430923, + 0.03258403390645981, + 0.039813362061977386, + -0.01668727956712246, + 0.027226485311985016, + -0.029213925823569298, + -0.008598217740654945, + 0.00931101106107235, + 0.026936721056699753, + -0.03083401545882225, + -0.05799110606312752, + -0.008277476765215397, + -0.014854338951408863, + -0.20012643933296204, + 0.012290815822780132, + 0.007194168865680695, + 0.06858328729867935, + -0.3296163082122803, + -0.11424986273050308, + 0.009912200272083282, + -0.06211454048752785, + 0.0007546336855739355, + 0.03507614880800247, + 0.10649498552083969, + -0.03036407195031643, + 0.0646015852689743, + -0.01595110446214676, + -0.16919563710689545, + 0.0013865949586033821, + -0.08339446783065796, + 0.06962471455335617, + 0.016058098524808884, + -0.04729780554771423, + 0.010602935217320919, + 0.01470863912254572, + 0.06903938204050064, + 0.014901719056069851, + -0.15120048820972443, + 0.016727851703763008, + 0.05003673583269119, + 0.04370126873254776, + 0.029703885316848755, + 0.021875420585274696, + 0.026293285191059113, + -0.01048936415463686, + 0.00040810942300595343, + -0.015616541728377342, + -0.062451593577861786, + 0.010016348212957382, + -0.06790193170309067, + -0.02077331207692623, + 0.007985175587236881, + -0.04435744881629944, + 0.06920231133699417, + 0.018344474956393242, + 0.028591370210051537, + 0.021957838907837868, + 0.0017570338677614927, + 0.036665257066488266, + 0.015438515692949295, + -0.0006347382441163063, + 0.04621066153049469, + -0.001942177303135395, + 0.010664877481758595, + -0.016754357144236565, + 0.006541184149682522, + -0.027716301381587982, + -0.0058586387895047665, + -0.005346015095710754, + 0.020482052117586136, + 0.06882552057504654, + 0.0026622572913765907, + 0.016321638599038124, + 0.017728103324770927, + -0.13356441259384155, + 0.030281176790595055, + 1.0354949154134374e-05, + 0.050639618188142776, + 0.0013030078262090683, + -0.11136802285909653, + -0.006832807790488005, + -0.09628921747207642, + 0.046699415892362595, + 0.002175685251131654, + 0.008100612089037895, + 0.012449901551008224, + -0.01713990420103073, + -0.000769267207942903, + 0.022544430568814278, + -0.0018787183798849583, + -0.014189678244292736, + 0.37042510509490967, + -0.030317893251776695, + 0.012663356028497219, + -0.04071582853794098, + 0.01653047651052475, + 0.06578584760427475, + 0.005606585182249546, + 0.0029362838249653578, + -0.02035594917833805, + 0.016131827607750893, + -0.06512665003538132, + 0.020292088389396667, + 0.12818951904773712, + -0.00017647731874603778, + 0.0004811069811694324, + 0.013025660999119282, + -0.006004344671964645, + 0.011330580338835716, + 0.0021733916364610195, + -0.0026290342211723328, + 0.008579215034842491, + -0.017107143998146057, + 0.0032798980828374624, + 0.21415431797504425, + -0.011049880646169186, + 0.04915957152843475, + -0.01152863260358572, + 0.01988764852285385, + -0.30189022421836853, + 0.1491061896085739, + 0.022540517151355743, + 0.02323656715452671, + -0.0028044115751981735, + -0.02501249685883522, + 0.0016759912250563502, + 0.023405946791172028, + 0.0865691602230072, + 0.0056661744602024555, + 0.2334042638540268, + -0.05771901085972786, + 0.03428330272436142, + -0.05191519856452942, + 0.025708407163619995, + -0.11474912613630295, + 0.05345827341079712, + 0.050046734511852264, + -0.03785427287220955, + 0.02726786397397518, + 0.008640051819384098, + -0.05810163915157318, + 0.19147679209709167, + 0.12065602838993073, + -0.08667072653770447, + -0.12831886112689972, + 0.027053257450461388, + -0.1771622896194458, + -0.2615586817264557, + 0.112942636013031, + 0.002398239215835929, + 0.00907410029321909, + 0.059947770088911057, + 0.040937639772892, + 0.003431124845519662, + 0.012721046805381775, + -0.10228776186704636, + 0.04169567674398422, + -0.04826785624027252, + -0.021415220573544502, + 0.027615519240498543, + 0.16087181866168976, + 0.03552674129605293, + -0.36409878730773926, + 0.0015418739058077335, + 0.03940089792013168, + -0.12929502129554749, + 0.017082052305340767, + -0.07193783670663834, + 0.10395099222660065, + -0.2240910828113556, + -0.003303584409877658, + -0.0074868109077215195, + -0.13708709180355072, + 0.2098008245229721, + 0.013808795250952244, + -0.03606148064136505, + 0.001965852687135339, + 0.04186573252081871, + 0.02105732634663582, + -0.11873909085988998, + -0.08529136329889297, + 0.0060731275007128716, + 0.04803553968667984, + 0.07665349543094635, + 0.026997262611985207, + 0.05191565304994583, + 0.09013131260871887, + 0.013081093318760395, + 0.04667182266712189, + -0.19899451732635498, + 0.004642056301236153, + 0.0025570227298885584, + -0.2640555500984192, + 0.008254006505012512, + 0.05971720814704895, + -0.002980671590194106, + 0.0011313167633488774, + -0.004445134196430445, + 0.01951296627521515, + -0.006634386721998453, + -0.008033698424696922, + 0.012400158680975437, + -0.15906694531440735, + 0.007047838997095823, + 0.0003521084145177156, + -0.00517050176858902, + -0.0003226286207791418, + -0.01226231548935175, + -0.06750697642564774, + -0.03061128593981266, + -0.0027100055012851954, + 0.004726986400783062, + 0.010185977444052696, + 0.021205933764576912, + -0.05105980113148689, + -0.006725164130330086, + 0.26042309403419495, + 0.003935054875910282, + 0.009450466372072697, + -0.009512278251349926, + 0.036205559968948364, + 0.0066987741738557816, + 0.05687355250120163, + -0.0070350514724850655, + 0.021287698298692703, + 0.004246287513524294, + -0.004053668584674597, + 0.0030501342844218016, + -0.003596516093239188, + 0.00571554945781827, + 0.039099883288145065, + 0.06648323684930801, + 0.011140268296003342, + 0.002779693342745304, + 0.0004113377653993666, + 0.0019621821120381355, + 0.002047213725745678, + -9.034215327119455e-05, + 0.006674906238913536, + -0.024464793503284454, + 4.372629337012768e-05, + 0.04560312256217003, + 0.029951298609375954, + 0.0053787860088050365, + 0.010052027180790901, + 0.0018156497972086072, + 0.001613074098713696, + -0.3710610568523407, + 0.18385423719882965, + 0.0197732076048851, + -2.409513217571657e-05, + 0.043657880276441574, + 0.029824273660779, + -0.0015272254822775722, + -0.0009817760437726974, + 0.030571524053812027, + 0.05133187025785446, + 0.021092001348733902, + -0.022430723533034325, + -0.011050102300941944, + -0.01653454266488552, + 0.00856624636799097, + 0.007617316208779812, + 0.023697074502706528, + -0.00541776092723012, + -0.06940567493438721, + -0.024501511827111244, + 0.0029131292831152678, + 0.005110545549541712, + 0.02394089475274086, + 0.009317552670836449, + -0.05198051407933235, + -0.14872707426548004, + -0.03553030639886856, + 0.05354774370789528, + 0.053996339440345764, + 0.016679847612977028, + -0.4505158066749573, + 0.006403166800737381, + -0.014287465251982212, + 0.010499212890863419, + 0.00510875741019845, + 0.0230255089700222, + -0.04791099205613136, + -0.08405473828315735, + -0.00807158276438713, + -0.016310568898916245, + -0.018034789711236954, + -0.03381670266389847, + 0.038599055260419846, + 0.01189411524683237, + 0.0038598189130425453, + 0.0077203805558383465, + -0.0006835742969997227, + 0.3038807809352875, + 0.00930703990161419, + -0.017654214054346085, + -0.029550395905971527, + 0.0014829621650278568, + -0.010562432929873466, + -0.011867706663906574, + -0.008104459382593632, + 0.008003979921340942, + -0.028282882645726204, + 0.00898829661309719, + -0.04963170364499092, + 0.014971665106713772, + 0.028662119060754776, + 0.055792808532714844, + 0.018142173066735268, + 0.029526766389608383, + 0.04726170003414154, + 0.020290115848183632, + -0.01347910612821579, + -0.027794860303401947, + -0.033374592661857605, + 0.05699307844042778, + -0.005888971965759993, + 0.009723466821014881, + 0.011825029738247395, + 0.0005665962235070765, + -0.22433574497699738, + 0.04777664318680763, + 0.054696254432201385, + 0.06447272002696991, + 0.006656138692051172, + -0.2656468152999878, + -0.006602808367460966, + -0.04309352487325668, + 0.024392882362008095, + -0.046948980540037155, + 0.17317010462284088, + -0.014694501645863056, + 0.09150391072034836, + 0.05414793640375137, + -0.0034523033536970615, + -0.029682809486985207, + -0.11646991223096848, + 0.036394182592630386, + -0.008510537445545197, + -0.09555189311504364, + 0.012331446632742882, + 0.022554755210876465, + 0.037040166556835175, + 0.011939534917473793, + -0.035405583679676056, + -0.008284371346235275, + 0.008629710413515568, + -0.0017152110813185573, + -0.01656493730843067, + 0.02205522358417511, + -0.008015291765332222, + -0.02198217809200287, + -0.08165504783391953, + 0.018647879362106323, + 0.010489191859960556, + 0.0009643095545470715, + 0.08301698416471481, + 0.00795030314475298, + -0.08973152935504913, + 0.05324552580714226, + 0.0187348835170269, + 0.00770497927442193, + 0.016434336081147194, + 0.0031714467331767082, + 0.031489044427871704, + -0.01682765781879425, + -0.0006042059976607561, + 0.006229344755411148, + 0.0031935630831867456, + -0.03694210946559906, + -0.027148112654685974, + 0.03319454565644264, + 0.013541879132390022, + 0.04362545907497406, + 0.010766182094812393, + 0.01287879142910242, + 0.02723391354084015, + 0.01831277459859848, + -0.0028144901152700186, + 0.0317537821829319, + -0.05053209140896797, + 0.03341667726635933, + 0.009338690899312496, + 0.030376508831977844, + 0.028512636199593544, + 0.002190604107454419, + 0.031132254749536514, + 0.04174429178237915, + 0.025147251784801483, + 0.02602408640086651, + 0.022863827645778656, + 0.024160150438547134, + 0.04043813422322273, + 0.011693909764289856, + 0.008020071312785149, + 0.010814648121595383, + 0.014862221665680408, + 0.043966785073280334, + 0.04133215174078941, + 0.03920775279402733, + 0.02128027193248272, + -0.0024078795686364174, + 0.03185494989156723, + 0.030951442196965218, + 0.008766901679337025, + -0.0013500713976100087, + 0.012680909596383572, + 0.01911563239991665, + 0.02226334996521473, + 0.03873631730675697, + 0.005242412444204092, + 0.02335301972925663, + 0.00577192846685648, + 0.0019918885082006454, + 0.019501060247421265, + 0.048295676708221436, + 0.027288099750876427, + 0.03500128164887428, + 0.032504353672266006, + 0.03619033470749855, + 0.022762063890695572, + 0.014124974608421326, + 0.04055529460310936, + 0.040181197226047516, + 0.04843837395310402, + 0.019578352570533752, + 0.04370861127972603, + 0.024640914052724838, + 0.027013463899493217, + 0.04700532928109169, + 0.018523193895816803, + 0.03569294884800911, + 0.031140455976128578, + 0.010298499837517738, + 0.03979840502142906, + 0.015059049241244793, + 0.020604899153113365, + 0.010335667058825493, + 0.02557498589158058, + 0.015946611762046814, + 0.018900645896792412, + 0.05494159087538719, + 0.015756357461214066, + 0.0452926866710186, + 0.04820817708969116, + -0.0183499027043581, + 0.04002442955970764, + -0.08226092159748077, + -0.034417178481817245, + 0.059122342616319656, + 0.028960591182112694, + -0.020427608862519264, + -0.043222296983003616, + 0.023134637624025345, + -0.014232538640499115, + -0.06970997899770737, + -0.0035826240200549364, + -0.015384080819785595, + -0.0695020854473114, + 0.03645527362823486, + 0.013986784033477306, + -0.027729706838726997, + -0.05711805075407028, + -0.0763891413807869, + -0.16338491439819336, + -0.02358265034854412, + -0.004730133805423975, + 0.022057903930544853, + -0.011578230187296867, + 0.040772147476673126, + -0.059327173978090286, + -0.03819728270173073, + -0.050089117139577866, + -0.005152902565896511, + -0.3071111738681793, + -0.010683669708669186, + 0.030922774225473404, + 0.08924981951713562, + 0.005679265595972538, + 0.06334424018859863, + 0.016136568039655685, + -0.02575727365911007, + -0.012562219053506851, + 0.007206748705357313, + -0.1373208612203598, + -0.010450832545757294, + -0.05991309881210327, + -0.006700845435261726, + -0.006468744482845068, + -0.02040017955005169, + -0.010068708099424839, + 0.008442427963018417, + 0.012259873561561108, + -0.002103718463331461, + -0.019605906680226326, + -0.010690353810787201, + 0.0005222380859777331, + -0.015031278133392334, + -0.012983204796910286, + -0.03552224859595299, + -0.007792052812874317, + -0.035602111369371414, + -0.03479204699397087, + -0.02480080910027027, + -0.05733964219689369, + 4.38804054283537e-05, + -0.021825626492500305, + -0.03287259489297867, + -0.05437042564153671, + -0.007981077767908573, + 0.023045696318149567, + 0.05785335600376129, + 0.03685669228434563, + 0.04314129799604416, + -0.005843586288392544, + -0.024806369096040726, + -0.02562016434967518, + 0.0015172295970842242, + -0.01568800024688244, + -0.005925294477492571, + 0.010173594579100609, + 0.06834683567285538, + 0.024159085005521774, + -0.009547322988510132, + 0.014080812223255634, + 0.013578452169895172, + 0.035671167075634, + 0.01240566186606884, + -0.021352441981434822, + 0.05245270952582359, + -0.008943279273808002, + -0.010131126269698143, + 0.02976749651134014, + 0.0600045844912529, + 0.0014893191400915384, + 0.03796907886862755, + 0.01258794590830803, + -0.025344882160425186, + 0.14140591025352478, + 0.028354406356811523, + 0.0035325682256370783, + 0.05017172172665596, + 0.01994139887392521, + 0.03679897263646126, + -0.009579945355653763, + -0.012607194483280182, + -0.00034231581958010793, + 0.00046832446241751313, + 0.057916246354579926, + 0.02351403795182705, + 0.06157909706234932, + 0.00789523497223854, + -0.018361341208219528, + 0.0018971840618178248, + -0.007180131506174803, + -0.0010631990153342485, + -0.03140748664736748, + -0.028505641967058182, + 0.010669395327568054, + -0.036474280059337616, + 0.01703447848558426, + 0.04667484760284424, + -0.007303370162844658, + 0.01768752932548523, + 0.012412219308316708, + 0.013702306896448135, + 0.07651616632938385, + 0.05469715967774391, + 0.013292597606778145, + -0.006288900971412659, + 0.0215559434145689, + 0.010094149969518185, + -0.024216346442699432, + -0.15225785970687866, + 0.05467289313673973, + 0.019871067255735397, + 0.04662928730249405, + 0.05072600021958351, + -0.011824453249573708, + -0.028083933517336845, + 0.013322187587618828, + -0.044827401638031006, + 0.05955006927251816, + -0.006152187939733267, + 0.013426700606942177, + -0.014220507815480232, + 0.022510837763547897, + 0.019426455721259117, + -0.05546477064490318, + -0.49202534556388855, + 0.026985207572579384, + -0.08852843940258026, + 0.07166163623332977, + 0.05509938299655914, + -0.42284780740737915, + -0.05131356418132782, + 0.0196990966796875, + -0.008681846782565117, + 0.02739996463060379, + 0.0010900507913902402, + 0.04289104416966438, + -0.06694932281970978, + 0.05930810049176216, + -0.02174360118806362, + 0.03464379161596298, + 0.018284866586327553, + 0.018807150423526764, + 0.019874336197972298, + -0.03665176033973694, + -0.2980017066001892, + 0.050937239080667496, + -0.013874954544007778, + -0.0229057464748621, + 0.016420641914010048, + 0.024160616099834442, + -0.10750921070575714, + -0.010134756565093994, + 0.026874780654907227, + 0.007151094265282154, + 0.06304068863391876, + -0.11811652034521103, + -0.12590888142585754, + 0.031846947968006134, + -0.06898463517427444, + 0.03395693376660347, + -0.00010166154243052006, + -0.19019480049610138, + 0.06616076827049255, + -0.035927142947912216, + 0.08526375889778137, + 0.0015017242403700948, + -0.009137739427387714, + 0.04529058188199997, + -0.23621641099452972, + 0.02148340456187725, + -0.02741178683936596, + -0.20779411494731903, + ] + value = numpy.array(list_value, dtype=numpy.float32).reshape((64, 64, 1, 1)) + tensor = numpy_helper.from_array(value, name="onnx::Conv_504") + + initializers.append(tensor) + + list_value = [ + 5.195802688598633, + 0.940099835395813, + -7.016428470611572, + 5.185446739196777, + -4.134859085083008, + 2.0121846199035645, + 5.215719223022461, + 3.371406078338623, + 3.7616095542907715, + -3.6593239307403564, + 15.99945068359375, + 3.306276321411133, + 5.790191173553467, + 6.33050537109375, + 3.4512906074523926, + 2.5531861782073975, + 4.278702259063721, + 4.350361347198486, + 8.025779724121094, + -2.8830037117004395, + 2.915111541748047, + 3.592482805252075, + 5.810481071472168, + 3.4743332862854004, + 3.5245680809020996, + 1.8243598937988281, + 8.069726943969727, + 1.401036024093628, + 5.110081672668457, + -12.873579978942871, + 10.977816581726074, + 5.909627437591553, + -0.4007779359817505, + -20.147268295288086, + 6.649413585662842, + 3.325921058654785, + 5.84471321105957, + 4.47447395324707, + 3.754193067550659, + -5.167671203613281, + 3.2778055667877197, + -9.067073822021484, + 2.6243438720703125, + 1.7002031803131104, + 5.476454734802246, + 2.510835647583008, + 3.856968402862549, + 2.3172807693481445, + 12.462139129638672, + 7.355924129486084, + 4.140628814697266, + 4.807559967041016, + 5.7524309158325195, + 4.128836154937744, + 11.4532470703125, + -12.482564926147461, + 5.590144157409668, + 0.9172697067260742, + 4.356811046600342, + 0.9934853315353394, + -4.3548994064331055, + 15.853201866149902, + -5.241130828857422, + 5.9644365310668945, + ] + value = numpy.array(list_value, dtype=numpy.float32) + tensor = numpy_helper.from_array(value, name="onnx::Conv_505") + + initializers.append(tensor) + + # inputs + + inputs.append(make_tensor_value_info("input", 1, ["batch_size", 3, 32, 32])) + + # outputs + + outputs.append(make_tensor_value_info("/layer1/layer1.0/relu/Relu_output_0", 1, ["batch_size", 64, 8, 8])) + + # nodes + + node = make_node( + "Conv", + ["input", "onnx::Conv_501", "onnx::Conv_502"], + ["/conv1/Conv_output_0"], + name="/conv1/Conv", + dilations=[1, 1], + group=1, + kernel_shape=[7, 7], + pads=[3, 3, 3, 3], + strides=[2, 2], + domain="", + ) + nodes.append(node) + + node = make_node("Relu", ["/conv1/Conv_output_0"], ["/relu/Relu_output_0"], name="/relu/Relu", domain="") + nodes.append(node) + + node = make_node( + "MaxPool", + ["/relu/Relu_output_0"], + ["/maxpool/MaxPool_output_0"], + name="/maxpool/MaxPool", + ceil_mode=0, + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[2, 2], + domain="", + ) + nodes.append(node) + + node = make_node( + "Conv", + ["/maxpool/MaxPool_output_0", "onnx::Conv_504", "onnx::Conv_505"], + ["/layer1/layer1.0/conv1/Conv_output_0"], + name="/layer1/layer1.0/conv1/Conv", + dilations=[1, 1], + group=1, + kernel_shape=[1, 1], + pads=[0, 0, 0, 0], + strides=[1, 1], + domain="", + ) + nodes.append(node) + + node = make_node( + "Relu", + ["/layer1/layer1.0/conv1/Conv_output_0"], + ["/layer1/layer1.0/relu/Relu_output_0"], + name="/layer1/layer1.0/relu/Relu", + domain="", + ) + nodes.append(node) + + # opsets + opset_imports = [make_opsetid(domain, 1 if version is None else version) for domain, version in opsets.items()] + + # graph + graph = make_graph(nodes, "torch_jit", inputs, outputs, initializers) + # '7' + + onnx_model = make_model(graph, opset_imports=opset_imports, functions=functions) + onnx_model.ir_version = 7 + onnx_model.producer_name = "pytorch" + onnx_model.producer_version = "" + onnx_model.domain = "" + onnx_model.model_version = 0 + onnx_model.doc_string = "" + set_model_props(onnx_model, {}) + + return onnx_model diff --git a/onnxruntime/test/python/quantization/test_op_gavgpool.py b/onnxruntime/test/python/quantization/test_op_gavgpool.py index aa7a1833dd..ea01df6810 100644 --- a/onnxruntime/test/python/quantization/test_op_gavgpool.py +++ b/onnxruntime/test/python/quantization/test_op_gavgpool.py @@ -103,10 +103,10 @@ def quantize_gavgpool_test(self, activation_type, weight_type, extra_options={}) quant_nodes = { "QLinearConv": 1, - "GlobalAveragePool": 1, - "QLinearGlobalAveragePool": 1, - "QuantizeLinear": 1, - "DequantizeLinear": 1, + "GlobalAveragePool": 0, + "QLinearGlobalAveragePool": 2, + "QuantizeLinear": 2, + "DequantizeLinear": 2, } check_op_type_count(self, model_q8_path, **quant_nodes) qnode_io_qtypes = { diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py new file mode 100644 index 0000000000..02f51cc4fa --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import tempfile +import unittest +from importlib.util import find_spec +from pathlib import Path +from typing import Dict, Tuple, Union + +import numpy as np +import onnx +from onnx import TensorProto, helper +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count + +from onnxruntime.quantization import quant_utils + + +class TestOpMatMul4Bits(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="test_matmul4bits.") + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def fill_int4_data(self, shape: Union[int, Tuple[int, ...]], symmetric: bool) -> np.ndarray: + line = np.zeros(shape) + line = line.reshape(-1) + + if symmetric: + v = -2.0 + for i in range(line.shape[0]): + if v == 0 or v == -3 or v == 3: + v += 1 + line[i] = v + v += 1 + if v >= 8: + v = -8 + else: + v = 0.0 + for i in range(line.shape[0]): + line[i] = v + v += 1 + if v >= 16: + v = 0 + + return line.reshape(shape) + + def input_feeds(self, n: int, name2shape: Dict[str, Union[int, Tuple[int, ...]]]) -> TestDataFeeds: + input_data_list = [] + for _i in range(n): + inputs = {} + for name, shape in name2shape.items(): + inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) + input_data_list.extend([inputs]) + dr = TestDataFeeds(input_data_list) + return dr + + def construct_model_matmul(self, output_model_path: str, symmetric: bool) -> None: + # (input) + # | + # MatMul + # | + # (output) + input_name = "input" + output_name = "output" + initializers = [] + + def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str): + weight_data = self.fill_int4_data(weight_shape, symmetric).astype(np.float32) + initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) + return onnx.helper.make_node( + "MatMul", + [input_name, weight_name], + [output_name], + ) + + in_features = 52 + out_features = 288 + # make MatMul node + matmul_node = make_matmul( + input_name, + [in_features, out_features], + "linear1.weight", + output_name, + ) + + # make graph + input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [-1, in_features]) + output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [-1, out_features]) + graph_name = "matmul_4bits_test" + graph = helper.make_graph( + [matmul_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = 7 # use stable onnx ir version + + onnx.save(model, output_model_path) + + def quant_test( + self, + model_fp32_path: str, + data_reader: TestDataFeeds, + block_size: int, + is_symmetric: bool, + ): + model_int4_path = str( + Path(self._tmp_model_dir.name).joinpath(f"MatMulNBits_{block_size}_{is_symmetric}.onnx").absolute() + ) + + # Quantize fp32 model to int4 model + from onnxruntime.quantization import matmul_4bits_quantizer + + model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) + quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric) + quant.process() + quant.model.save_model_to_file(model_int4_path, False) + + quant_nodes = {"MatMulNBits": 1} + check_op_type_count(self, model_int4_path, **quant_nodes) + + data_reader.rewind() + + try: + check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next()) + except Exception as exception: + if "4b quantization not yet supported on this hardware platform!" in exception.args[0]: + # Currently we don't have int4 quantization support on all platforms, has to tolerate this exception + pass + else: + raise exception + + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_symmetric(self): + np.random.seed(13) + + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=True) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test(model_fp32_path, data_reader, 32, True) + + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_offsets(self): + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=False) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test(model_fp32_path, data_reader, 32, False) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py b/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py new file mode 100644 index 0000000000..88432d75c6 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import tempfile +import unittest +from importlib.util import find_spec +from pathlib import Path +from typing import Dict, Tuple, Union + +import numpy as np +import onnx +from onnx import TensorProto, helper +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count + +from onnxruntime.quantization import quant_utils + +quant_maps = { + 0: [ + 0.00000000, + 5.208333333e-03, + 0.66666667, + 1.00000000, + 0.33333333, + 0.50000000, + 0.16666667, + 0.25000000, + -0.00000000, + -5.208333333e-03, + -0.66666667, + -1.00000000, + -0.33333333, + -0.50000000, + -0.16666667, + -0.25000000, + ], + 1: [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ], +} + + +class TestOpMatMulBnb4(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="test_matmulbnb4.") + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def fill_bnb4_data(self, shape: Tuple[int, int], quant_type: int) -> np.ndarray: + rows, cols = shape + line = np.zeros(shape) + line = line.reshape(-1) + quant_map = np.array(quant_maps[quant_type], dtype=np.float32) + + v = 0 + for i in range(line.shape[0]): + line[i] = quant_map[v] + v += 1 + if v >= 16: + v = 0 + + # bnb quantization quantizes weight.T after flattening + line = line.reshape(cols, rows).transpose() + return line.reshape(shape) + + def input_feeds(self, n: int, name2shape: Dict[str, Union[int, Tuple[int, ...]]]) -> TestDataFeeds: + input_data_list = [] + for _i in range(n): + inputs = {} + for name, shape in name2shape.items(): + inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) + input_data_list.extend([inputs]) + dr = TestDataFeeds(input_data_list) + return dr + + def construct_model_matmul(self, output_model_path: str, quant_type: int) -> None: + # (input) + # | + # MatMul + # | + # (output) + input_name = "input" + output_name = "output" + initializers = [] + + def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str): + weight_data = self.fill_bnb4_data(weight_shape, quant_type).astype(np.float32) + initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) + return onnx.helper.make_node( + "MatMul", + [input_name, weight_name], + [output_name], + ) + + # for this to work (in_features * out_features) % block_size == 0 + in_features = 52 + out_features = 288 + # make MatMul node + matmul_node = make_matmul( + input_name, + [in_features, out_features], + "linear1.weight", + output_name, + ) + + # make graph + input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [-1, in_features]) + output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [-1, out_features]) + graph_name = "matmul_bnb4_test" + graph = helper.make_graph( + [matmul_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = 7 # use stable onnx ir version + + onnx.save(model, output_model_path) + + def quant_test(self, quant_type: int, block_size: int): + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath(f"matmul_fp32_{quant_type}.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, quant_type) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + + model_bnb4_path = str( + Path(self._tmp_model_dir.name).joinpath(f"MatMulBnb4_{quant_type}_{block_size}.onnx").absolute() + ) + + # Quantize fp32 model to bnb4 model + from onnxruntime.quantization import matmul_bnb4_quantizer + + model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) + quant = matmul_bnb4_quantizer.MatMulBnb4Quantizer(model, quant_type, block_size) + quant.process() + quant.model.save_model_to_file(model_bnb4_path, False) + + quant_nodes = {"MatMulBnb4": 1} + check_op_type_count(self, model_bnb4_path, **quant_nodes) + + data_reader.rewind() + + try: + check_model_correctness(self, model_fp32_path, model_bnb4_path, data_reader.get_next()) + except Exception as exception: + raise exception + + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4" + ) + def test_quantize_matmul_bnb4_fp4(self): + np.random.seed(13) + self.quant_test(0, 64) + + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4" + ) + def test_quantize_matmul_bnb4_nf4(self): + np.random.seed(13) + self.quant_test(1, 64) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_quant_shape_inference.py b/onnxruntime/test/python/quantization/test_quant_shape_inference.py new file mode 100644 index 0000000000..015ab04d2c --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quant_shape_inference.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python +# coding: utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import tempfile +import unittest + +import numpy as np +from onnx import TensorProto, helper, numpy_helper + +from onnxruntime.tools import symbolic_shape_infer + + +class TestQLinearOpsShapeInfer(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="test_quant_shape_infer.") + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def get_model(self, node, input_shape, initializer): + # Create a single node model + return helper.make_model( + opset_imports=[ + helper.make_operatorsetid("", 12), + helper.make_operatorsetid("com.microsoft", 1) + ], + graph=helper.make_graph( + name="qlinear_test", + inputs=[helper.make_tensor_value_info("input", TensorProto.INT8, shape=input_shape)], + outputs=[helper.make_tensor_value_info("output", TensorProto.INT8, shape=[])], + initializer=initializer, + value_info=[], + nodes=[node] + ), + ) + + def infer_out_shape(self, model): + inf_onnx = symbolic_shape_infer.SymbolicShapeInference.infer_shapes( + in_mp=model, + auto_merge=True, + int_max=100000, + guess_output_rank=True, + ) + out_shape_proto = inf_onnx.graph.value_info[-1].type.tensor_type.shape + return [sh.dim_value for sh in out_shape_proto.dim] + + def test_shape_qlinear_add(self): + model = self.get_model( + helper.make_node( + "QLinearAdd", + inputs=[ + "input", + "input_scale", + "input_zero_point", + "add_bias_quantized", + "add_bias_scale", + "add_bias_zero_point", + "add_out_scale", + "add_out_zero_point", + ], + outputs=["output"], + name="quant_node", + domain="com.microsoft"), + [1, 8, 14, 14], + [ + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="input_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="input_zero_point"), + numpy_helper.from_array(np.ones([8, 1, 1]).astype("int8"), name="add_bias_quantized"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="add_bias_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="add_bias_zero_point"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="add_out_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="add_out_zero_point"), + ] + ) + self.assertEqual(self.infer_out_shape(model), + [1, 8, 14, 14], + "Wrong shape inferred for quantized network output") + + def test_shape_qlinear_mult(self): + model = self.get_model( + helper.make_node( + "QLinearMul", + inputs=[ + "mul_bias_quantized", + "mul_bias_scale", + "mul_bias_zero_point", + "input", + "input_scale", + "input_zero_point", + "mul_out_scale", + "mul_out_zero_point", + ], + outputs=["output"], + name="quant_node", + domain="com.microsoft"), + [1, 8, 14, 14], + [ + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="input_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="input_zero_point"), + numpy_helper.from_array(np.ones([8, 1, 1]).astype("int8"), name="mul_bias_quantized"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="mul_bias_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="mul_bias_zero_point"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="mul_out_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="mul_out_zero_point"), + ] + ) + self.assertEqual(self.infer_out_shape(model), + [1, 8, 14, 14], + "Wrong shape inferred for quantized network output") + + def test_shape_qlinear_concat(self): + model = self.get_model( + helper.make_node( + "QLinearConcat", + inputs=[ + "concat_out_scale", + "concat_out_zero_point", + "input", + "input_scale", + "input_zero_point", + "concat_bias_quantized", + "concat_out_scale", + "concat_out_zero_point", + ], + outputs=["output"], + name="quant_node", + domain="com.microsoft", + axis=1, + ), + [1, 8, 14, 14], + [ + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="input_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="input_zero_point"), + numpy_helper.from_array(np.ones([1, 8, 14, 14]).astype("int8"), name="concat_bias_quantized"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="concat_bias_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="concat_bias_zero_point"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="concat_out_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="concat_out_zero_point"), + ] + ) + self.assertEqual(self.infer_out_shape(model), + [1, 16, 14, 14], + "Wrong shape inferred for quantized network output") + + def test_shape_qlinear_leaky_relu(self): + model = self.get_model( + helper.make_node( + "QLinearLeakyRelu", + inputs=[ + "input", + "input_scale", + "input_zero_point", + "relu_out_scale", + "relu_out_zero_point", + ], + outputs=["output"], + name="quant_node", + domain="com.microsoft", + alpha=0.009999999776482582, + ), + [1, 16, 14, 14], + [ + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="input_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="input_zero_point"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="relu_out_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="relu_out_zero_point"), + ] + ) + self.assertEqual(self.infer_out_shape(model), + [1, 16, 14, 14], + "Wrong shape inferred for quantized network output") + + def test_shape_qlinear_average_pool(self): + model = self.get_model( + helper.make_node( + "QLinearAveragePool", + inputs=[ + "input", + "input_scale", + "input_zero_point", + "pool_out_scale", + "pool_out_zero_point", + ], + outputs=["output"], + name="quant_node", + domain="com.microsoft", + kernel_shape=[2, 2], + strides=[2, 2], + ), + [1, 16, 14, 14], + [ + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="input_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="input_zero_point"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="pool_out_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="pool_out_zero_point"), + ] + ) + self.assertEqual(self.infer_out_shape(model), + [1, 16, 7, 7], + "Wrong shape inferred for quantized network output") + + def test_shape_qlinear_sigmoid(self): + model = self.get_model( + helper.make_node( + "QLinearSigmoid", + inputs=[ + "input", + "input_scale", + "input_zero_point", + "sigmoid_out_scale", + "sigmoid_out_zero_point", + ], + outputs=["output"], + name="quant_node", + domain="com.microsoft", + ), + [1, 16, 7, 7], + [ + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="input_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="input_zero_point"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="sigmoid_out_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="sigmoid_out_zero_point"), + ] + ) + self.assertEqual(self.infer_out_shape(model), + [1, 16, 7, 7], + "Wrong shape inferred for quantized network output") + + def test_shape_qlinear_global_average_pool(self): + model = self.get_model( + helper.make_node( + "QLinearGlobalAveragePool", + inputs=[ + "input", + "input_scale", + "input_zero_point", + "gap_out_scale", + "gap_out_zero_point", + ], + outputs=["output"], + name="quant_node", + domain="com.microsoft", + channels_last=0, + ), + [1, 16, 7, 7], + [ + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="input_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="input_zero_point"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="gap_out_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="gap_out_zero_point"), + ] + ) + self.assertEqual(self.infer_out_shape(model), + [1, 16, 1, 1], + "Wrong shape inferred for quantized network output") + + def test_shape_qlinear_gemm(self): + model = self.get_model( + helper.make_node( + "QGemm", + inputs=[ + "input", + "input_scale", + "input_zero_point", + "gemm_wt_quantized", + "gemm_wt_scale", + "gemm_wt_zero_point", + "gemm_bias_quantized", + "gemm_out_scale", + "gemm_out_zero_point", + ], + outputs=["output"], + name="quant_node", + domain="com.microsoft", + alpha=1.0, + transB=1, + ), + [1, 16], + [ + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="input_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="input_zero_point"), + numpy_helper.from_array(np.ones([32, 16]).astype("int8"), name="gemm_wt_quantized"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="gemm_wt_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="gemm_wt_zero_point"), + numpy_helper.from_array(np.ones([32]).astype("int32"), name="gemm_bias_quantized"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="gemm_out_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="gemm_out_zero_point"), + ] + ) + self.assertEqual(self.infer_out_shape(model), + [1, 32], + "Wrong shape inferred for quantized network output") + + def test_shape_qlinear_softmax(self): + model = self.get_model( + helper.make_node( + "QLinearSoftmax", + inputs=[ + "input", + "input_scale", + "input_zero_point", + "softmax_out_scale", + "softmax_out_zero_point", + ], + outputs=["output"], + name="quant_node", + domain="com.microsoft", + ), + [1, 32], + [ + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="input_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="input_zero_point"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="softmax_out_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="softmax_out_zero_point"), + ] + ) + self.assertEqual(self.infer_out_shape(model), + [1, 32], + "Wrong shape inferred for quantized network output") + + def test_shape_qlinear_conv_transpose(self): + model = self.get_model( + helper.make_node( + "QLinearConvTranspose", + inputs=[ + "input", + "input_scale", + "input_zero_point", + "conv_transpose_wt_quantized", + "weight_scale", + "weight_zero_point", + "conv_transpose_out_scale", + "conv_transpose_out_zero_point", + "conv_transpose_bias" + ], + outputs=["output"], + name="quant_node", + domain="com.microsoft", + auto_pad=b'NOTSET', + dilations=[1, 1], + group=1, + kernel_shape=[2, 2], + pads=[0, 0, 0, 0], + strides=[2, 2] + ), + [1, 32, 14, 14], + [ + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="input_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="input_zero_point"), + numpy_helper.from_array(np.ones([32, 64, 2, 2]).astype("int8"), name="conv_transpose_wt_quantized"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="weight_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="weight_zero_point"), + numpy_helper.from_array(np.array(0.007874015718698502, dtype="float32"), name="conv_transpose_out_scale"), + numpy_helper.from_array(np.array(0, dtype="int8"), name="conv_transpose_out_zero_point"), + numpy_helper.from_array(np.ones([64]).astype("int32"), name="conv_transpose_bias"), + ] + ) + self.assertEqual(self.infer_out_shape(model), + [1, 64, 28, 28], + "Wrong shape inferred for quantized network output") + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_quantize_static_resnet.py b/onnxruntime/test/python/quantization/test_quantize_static_resnet.py new file mode 100644 index 0000000000..1efa283af6 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantize_static_resnet.py @@ -0,0 +1,138 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import os +import random +import tempfile +import unittest + +import numpy as np +import onnx +from numpy.testing import assert_allclose +from onnx.numpy_helper import to_array +from resnet_code import create_model + +from onnxruntime import InferenceSession +from onnxruntime import __version__ as ort_version +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static +from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod + + +class FakeResnetCalibrationDataReader(CalibrationDataReader): + def __init__(self, batch_size: int = 16): + super().__init__() + self.dataset = [ + (np.random.rand(1, 3, 32, 32).astype(np.float32), random.randint(0, 9)) for _ in range(batch_size) + ] + self.iterator = iter(self.dataset) + + def get_next(self) -> dict: + try: + return {"input": next(self.iterator)[0]} + except Exception: + return None + + +class TestStaticQuantizationResNet(unittest.TestCase): + def test_quantize_static_resnet(self): + kwargs = { + "activation_type": QuantType.QUInt8, + "weight_type": QuantType.QInt8, + "calibrate_method": CalibrationMethod.Percentile, + "extra_options": { + "ActivationSymmetric": False, + "EnableSubgraph": False, + "ForceQuantizeNoInputCheck": False, + "MatMulConstBOnly": False, + "WeightSymmetric": True, + "extra.Sigmoid.nnapi": False, + }, + "nodes_to_exclude": None, + "nodes_to_quantize": None, + "op_types_to_quantize": None, + "per_channel": True, + "quant_format": QuantFormat.QDQ, + "reduce_range": False, + } + + proto = create_model() + + with tempfile.TemporaryDirectory() as temp: + model = os.path.join(temp, "resnet_first_nodes.onnx") + with open(model, "wb") as f: + f.write(proto.SerializeToString()) + + for per_channel in [True, False]: + kwargs["per_channel"] = per_channel + dataloader = FakeResnetCalibrationDataReader(16) + with self.subTest(per_channel=per_channel): + qdq_file = os.path.join( + temp, f"preprocessed-small-qdq-{1 if per_channel else 0}-ort-{ort_version}.onnx" + ) + quantize_static( + model_input=model, + model_output=qdq_file, + calibration_data_reader=dataloader, + use_external_data_format=False, + **kwargs, + ) + + # With onnxruntime==1.15.1, the initializer 'onnx::Conv_504_zero_point' is: + # * uint8(128) if per_channel is False + # * int8([0, 0, ....]) if per_channel is True + # With onnxruntime>1.16.0 + # * uint8(128) if per_channel is False + # * uint8([128, 128, ..., 127, ...]) if per_channel is True + # QLinearConv : zero point of per-channel filter must be same. + # That's why the quantization forces a symmetric quantization into INT8. + # zero_point is guaranted to be zero whatever the channel is. + + with open(qdq_file, "rb") as f: + onx = onnx.load(f) + for init in onx.graph.initializer: + arr = to_array(init) + if ( + arr.dtype == np.int8 + and "zero_point" not in init.name + and not init.name.endswith("quantized") + ): + raise AssertionError( + f"Initializer {init.name!r} has type {arr.dtype} and " + f"shape {arr.shape} but should be {np.uint8}." + ) + + sess = InferenceSession(qdq_file, providers=["CPUExecutionProvider"]) + shape = (1, 3, 32, 32) + size = np.prod(shape) + dummy = (np.arange(size) / float(size)).astype(np.float32).reshape(shape) + got = sess.run(None, {"input": dummy}) + self.assertEqual(got[0].shape, (1, 64, 8, 8)) + self.assertEqual(got[0].dtype, np.float32) + if per_channel: + expected = np.array( + [ + [[1.0862497091293335, 0.9609132409095764], [1.0862497091293335, 0.9191343784332275]], + [[0.7520190477371216, 1.0026921033859253], [1.0444709062576294, 1.0862497091293335]], + [[0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.9609132409095764, 0.7937979102134705]], + ], + dtype=np.float32, + ) + assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2) + else: + expected = np.array( + [ + [[1.428238868713379, 1.2602107524871826], [1.3442248106002808, 1.2182037830352783]], + [[0.8821475505828857, 1.0921826362609863], [1.1341897249221802, 1.1761966943740845]], + [[0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [1.2182037830352783, 1.050175666809082]], + ], + dtype=np.float32, + ) + assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py new file mode 100644 index 0000000000..e03a0167d0 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest +from importlib.util import find_spec + +import numpy as np +import numpy.typing as npt + + +def dequantize_blockwise_4bits(quant_values, scale, zero_point, valid_len): + blob_size = quant_values.shape[0] + block_size = blob_size * 2 + + quant_float = np.zeros((block_size), dtype=scale.dtype) + for b in range(blob_size): + v = quant_values[b] + quant_float[2 * b] = ((v & 0xF) - zero_point) * scale if 2 * b < valid_len else 0.0 + quant_float[2 * b + 1] = ((v >> 4) - zero_point) * scale if 2 * b + 1 < valid_len else 0.0 + return quant_float + + +def quantize_blockwise_4bits_ref(matrix_float: npt.ArrayLike, block_size: int, is_symmetric: bool): + if len(matrix_float.shape) != 2: + raise ValueError("Current int4 block quantization only supports 2D tensors!") + rows, cols = matrix_float.shape + + blob_size = block_size // 2 + k_blocks = (rows + block_size - 1) // block_size + padded_rows = k_blocks * block_size + pad_len = padded_rows - rows + matrix_float_padded = matrix_float + if pad_len > 0: + matrix_float_padded = np.pad(matrix_float, ((0, pad_len), (0, 0)), "constant") + + packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") + scales = np.zeros((cols * k_blocks), dtype=matrix_float_padded.dtype) + zero_point = np.full((cols * k_blocks + 1) // 2, 136, dtype="uint8") + + matrix_float_padded = np.transpose(matrix_float_padded) + for n in range(cols): + for k_id in range(0, rows, block_size): + if is_symmetric: + amax_idx = np.argmax(np.abs(matrix_float_padded[n, k_id : k_id + block_size])) + bmax = np.float32(matrix_float_padded[n, k_id + amax_idx]) + scale = bmax / (-8.0) + zp = 8 + else: + vmin = np.min(np.float32(matrix_float_padded[n, k_id : k_id + block_size])) + vmax = np.max(np.float32(matrix_float_padded[n, k_id : k_id + block_size])) + vmin = min(vmin, 0.0) + vmax = max(vmax, 0.0) + scale = (vmax - vmin) / ((1 << 4) - 1) + zero_point_fp = vmin + if scale != 0.0: + zero_point_fp = 0.0 - vmin / scale + zp = min(15, max(0, round(zero_point_fp))) + + reciprocal_scale = 1.0 / scale if scale != 0 else 0.0 + block_idx = n * k_blocks + k_id // block_size + scales[block_idx] = scale + zp_pair = zero_point[block_idx // 2] + zero_point[block_idx // 2] = ((zp_pair & 0x0F) | (zp << 4)) if (block_idx & 1) else ((zp_pair & 0xF0) | zp) + + blk_int0 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id : k_id + block_size : 2] * reciprocal_scale + zp)), + 0, + 15, + ).astype("uint8") + blk_int1 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 1 : k_id + block_size : 2] * reciprocal_scale + zp)), + 0, + 15, + ).astype("uint8") + packed[n, k_id // block_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4)) + + return (packed, scales, zero_point) + + +def quantize_blockwise_4bits_target(matrix_float: npt.ArrayLike, block_size: int, is_symmetric: bool): + if len(matrix_float.shape) != 2: + raise ValueError("Current int4 block quantization only supports 2D tensors!") + rows, cols = matrix_float.shape + + k_blocks = (rows + block_size - 1) // block_size + packed = np.zeros((cols, k_blocks, block_size // 2), dtype="uint8") + scales = np.zeros((cols * k_blocks), dtype=matrix_float.dtype) + zero_point = np.full((cols * k_blocks + 1) // 2, 136, dtype="uint8") + from onnxruntime.capi._pybind_state import quantize_matmul_4bits + + quantize_matmul_4bits(packed, matrix_float, scales, zero_point, block_size, cols, rows, is_symmetric) + return (packed, scales, zero_point) + + +class TestQuantizeBlockwise4Bits(unittest.TestCase): + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_blockwise_4bits(self): + for rows, cols in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]: + for block_size in [16, 32, 64, 128]: + for type in [np.float32, np.float16]: + for is_symmetric in [True, False]: + matrix_float = np.random.rand(rows, cols).astype(type) + quant_value_ref, scales_ref, zero_point_ref = quantize_blockwise_4bits_ref( + matrix_float, block_size, is_symmetric + ) + quant_value, scales, zero_point = quantize_blockwise_4bits_target( + matrix_float, block_size, is_symmetric + ) + assert np.allclose(scales_ref, scales) + assert np.allclose(zero_point_ref, zero_point) + for c in range(quant_value_ref.shape[0]): + for k in range(quant_value_ref.shape[1]): + block_idx = c * quant_value_ref.shape[1] + k + zp_idx = block_idx // 2 + assert np.allclose( + dequantize_blockwise_4bits( + quant_value_ref[c][k], + scales_ref[block_idx], + (zero_point_ref[zp_idx] >> 4) + if (block_idx & 1) + else (zero_point_ref[zp_idx] & 0x0F), + min(block_size, rows - k * block_size), + ), + dequantize_blockwise_4bits( + quant_value[c][k], + scales[block_idx], + (zero_point[zp_idx] >> 4) if (block_idx & 1) else (zero_point[zp_idx] & 0x0F), + min(block_size, rows - k * block_size), + ), + atol=1.2 * abs(scales[block_idx]), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py new file mode 100644 index 0000000000..9e9d05fae0 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest +from importlib.util import find_spec + +import numpy as np +import numpy.typing as npt + +quant_enums = {"FP4": 0, "NF4": 1} + + +def quantize_block_fp4(block: npt.ArrayLike): + # quantize a block of float32 values to uint8 by simulating a binary search using pivots + # could have used (block[:,None] - quant_map).argmin(axis=1) but there are some mismatches due to + # floating point precision + # block: 1-D array of normalized [-1,1] float32 values, len(block) % 2 == 0 + + # pivots to find the quantization index + # only half of the pivots are needed since the other half is symmetric + pivots = np.array( + [0.00260417, 0.0859375, 0.20833333, 0.29166667, 0.4166667, 0.583333, 0.8333333, 1], dtype=np.float32 + ) + # indices are not 0,1,2,3,4,5,6,7 because it is a floating point data type + pivot_indices = np.array([0, 1, 6, 7, 4, 5, 2, 3], dtype=np.uint8) + + # signs of the block + signs = (block < 0).astype(np.uint8) * 8 + + # find the uint8 quantization index + # argmax finds the first occurrence of True + quant_indices = pivot_indices[(np.abs(block)[:, None] <= pivots).argmax(axis=1)] + signs + + return np.bitwise_or(np.left_shift(quant_indices[::2], 4), quant_indices[1::2]) + + +def quantize_block_nf4(block: npt.ArrayLike): + pivots = np.array( + [ + -0.8480964004993439, + -0.6106329262256622, + -0.4599952697753906, + -0.33967943489551544, + -0.23460740596055984, + -0.13791173323988914, + -0.045525018125772476, + 0.03979014977812767, + 0.1202552504837513, + 0.2035212516784668, + 0.2920137718319893, + 0.3893125355243683, + 0.5016634166240692, + 0.6427869200706482, + 0.8614784181118011, + 1.0, + ], + dtype=np.float32, + ) + + quant_indices = (block[:, None] <= pivots).argmax(axis=1).astype(np.uint8) + + return np.bitwise_or(np.left_shift(quant_indices[::2], 4), quant_indices[1::2]) + + +def quantize_blockwise_bnb4_ref(matrix_float: npt.ArrayLike, block_size: int, quant_type: str, target=None): + if len(matrix_float.shape) != 2: + raise ValueError("Current bnb4 block quantization only supports 2D tensors!") + + numel = matrix_float.size + num_blocks = (numel + block_size - 1) // block_size + quantized_numel = (numel + 1) // 2 + + packed = np.zeros(quantized_numel, dtype=np.uint8) + absmax = np.zeros(num_blocks, dtype=matrix_float.dtype) + + flattened_matrix_float = matrix_float.flatten() + for block_idx in range(num_blocks): + block_len = min(block_size, numel - block_idx * block_size) + block = np.float32(flattened_matrix_float[block_idx * block_size : block_idx * block_size + block_len]) + + block_absmax = np.max(np.abs(block)) + reciprocal_absmax = 1.0 / block_absmax if block_absmax != 0 else 0.0 + absmax[block_idx] = block_absmax + + if block_len % 2 != 0: + block = np.append(block, 0.0) + block_len += 1 + + block *= reciprocal_absmax + start = block_idx * block_size // 2 + end = start + block_len // 2 + if quant_type == "FP4": + packed[start:end] = quantize_block_fp4(block) + else: + packed[start:end] = quantize_block_nf4(block) + + return (packed, absmax) + + +def quantize_blockwise_bnb4_target(matrix_float: npt.ArrayLike, block_size: int, quant_type: str): + if len(matrix_float.shape) != 2: + raise ValueError("Current int4 block quantization only supports 2D tensors!") + quant_type_enum = quant_enums[quant_type] + + n, k = matrix_float.shape # already transposed + numel = n * k + num_blocks = (numel + block_size - 1) // block_size + quantized_numel = (numel + 1) // 2 + + packed = np.zeros(quantized_numel, dtype="uint8") + absmax = np.zeros(num_blocks, dtype=matrix_float.dtype) + from onnxruntime.capi._pybind_state import quantize_matmul_bnb4 + + quantize_matmul_bnb4(packed, matrix_float, absmax, block_size, quant_type_enum, n, k) + return (packed, absmax) + + +class TestQuantizeBlockwiseBnb4(unittest.TestCase): + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4" + ) + def test_quantize_blockwise_bnb4(self): + for quant_type in ["FP4", "NF4"]: + for k, n in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]: + for block_size in [16, 32, 64, 128]: + for type in [np.float32, np.float16]: + matrix_float = np.random.uniform(-1, 1, (k, n)).astype(type) + quant_value_ref, absmax_ref = quantize_blockwise_bnb4_ref(matrix_float, block_size, quant_type) + quant_value, absmax = quantize_blockwise_bnb4_target(matrix_float, block_size, quant_type) + assert np.allclose(quant_value_ref, quant_value) + assert np.allclose(absmax_ref, absmax) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py new file mode 100644 index 0000000000..a9bef025a7 --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_gqa.py @@ -0,0 +1,339 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" +Benchmark performance of MultiHeadAttention with Nvidia GPU of Compute Capability 8.0, 8.6 or 8.9 in Linux: +sh benchmark_mha.sh +""" + +import math +import random +import statistics +import time + +import torch +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession, OrtValue, SessionOptions + + +class InputFormats: + QKV_BSNH = 0 + QKV_BNSH = 1 + + +class Config: + batch_size = 0 + sequence_length = 0 + kv_sequence_length = 0 + past_sequence_length = 0 + num_heads = 0 + kv_num_heads = 0 + head_size = 0 + + def __init__(self, b, s, s2, sp, n, n2, h): + self.batch_size = b + self.sequence_length = s + self.kv_sequence_length = s2 + self.past_sequence_length = sp + self.num_heads = n + self.kv_num_heads = n2 + self.head_size = h + + +def create_group_query_attention_graph_past( + config, causal=False, past_kv_format=InputFormats.QKV_BSNH, share_buffer=True +): + past_kv_seqlen = config.kv_sequence_length if share_buffer else config.past_sequence_length + present_kv_seqlen = ( + config.kv_sequence_length if share_buffer else config.past_sequence_length + config.sequence_length + ) + nodes = [ + helper.make_node( + "GroupQueryAttention", + [ + "query", + "key", + "value", + "past_key", + "past_value", + "past_sequence_length" if share_buffer else "", + ], + ["output", "present_key", "present_value"], + "GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + unidirectional=1 if causal else 0, + is_past_bsnh=1 if past_kv_format == InputFormats.QKV_BSNH else 0, + domain="com.microsoft", + ), + ] + + graph_input = [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else past_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "past_value", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else past_kv_seqlen, + config.head_size, + ], + ), + ] + if share_buffer: + graph_input += [ + helper.make_tensor_value_info( + "past_sequence_length", + TensorProto.INT32, + [1], + ) + ] + + graph_output = [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else present_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == InputFormats.QKV_BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == InputFormats.QKV_BSNH else present_kv_seqlen, + config.head_size, + ], + ), + ] + + graph = helper.make_graph( + nodes, + "GroupQueryAttention_Graph", + graph_input, + graph_output, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_gqa_session( + config: Config, + causal: bool = False, + past_format=InputFormats.QKV_BSNH, + share_buffer: bool = True, +) -> InferenceSession: + onnx_model_str = create_group_query_attention_graph_past(config, causal, past_format, share_buffer) + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + return ort_session + + +def bind_io(io_binding, input_dict, device, share_buffer=True): + io_binding.bind_cpu_input("query", input_dict["query"]) + io_binding.bind_cpu_input("key", input_dict["key"]) + io_binding.bind_cpu_input("value", input_dict["value"]) + io_binding.bind_input( + "past_key", "cuda", 0, "float16", input_dict["past_key"].shape(), input_dict["past_key"].data_ptr() + ) + io_binding.bind_input( + "past_value", + "cuda", + 0, + "float16", + input_dict["past_value"].shape(), + input_dict["past_value"].data_ptr(), + ) + io_binding.bind_output("output") + if share_buffer: + io_binding.bind_cpu_input("past_sequence_length", input_dict["past_sequence_length"]) + io_binding.bind_output( + "present_key", + device_type="cuda", + device_id=device, + element_type="float16", + shape=input_dict["past_key"].shape(), + buffer_ptr=input_dict["past_key"].data_ptr(), + ) + io_binding.bind_output( + "present_value", + device_type="cuda", + device_id=device, + element_type="float16", + shape=input_dict["past_value"].shape(), + buffer_ptr=input_dict["past_value"].data_ptr(), + ) + else: + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + + +def measure_latency(ort_session, io_binding): + start = time.time() + _ = ort_session.run_with_iobinding(io_binding) + end = time.time() + return end - start + + +def flops(batch, q_seqlen, kv_seqlen, head_size, num_heads): + return 4 * batch * q_seqlen * kv_seqlen * num_heads * head_size + + +def tflops_per_second(flop, time): + return (flop / time / 10**12) if not math.isnan(time) else 0.0 + + +def benchmark_op(session, io_binding, repeats=100): + # warm up session + _ = measure_latency(session, io_binding) + + latency_list = [] + for _ in range(repeats): + latency = measure_latency(session, io_binding) + latency_list.append(latency) + return statistics.mean(latency_list) + + +def run_tflops_test(dtype=torch.float16, repeats: int = 100): + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + print("---- GQA BSNH vs GQA BNSH ----") + print("op\tbatch\ts_kv\theads\th_dim\tms\tTFLOPS") + mean_bsnh_lat = 0 + mean_bnsh_lat = 0 + num_trials = 0 + share_buffer = True + random.seed(69) + for b in [1, 3, 8, 16]: + for s_q, s_kv in [(1, 128), (128, 256), (512, 512), (128, 1024), (1, 2048)]: + for n_q, n_kv in [(8, 8), (16, 8), (32, 32), (12, 3), (128, 64)]: + for h in [32, 64, 128]: + sp = random.randint(1, s_kv - 1) if s_kv - 1 > 0 else 0 + config = Config(b, s_q, s_kv, sp, n_q, n_kv, h) + + bsnh_session = create_gqa_session( + config, + causal=False, + past_format=InputFormats.QKV_BSNH, + share_buffer=share_buffer, + ) + bnsh_session = create_gqa_session( + config, + causal=False, + past_format=InputFormats.QKV_BNSH, + share_buffer=share_buffer, + ) + + q = torch.randn(b, s_q, n_q * h, device=device, dtype=dtype) + kv = torch.randn(b, s_q, 2, n_kv * h, device=device, dtype=dtype) + k, v = kv.unbind(dim=2) + + past_kv = torch.rand(b, s_kv if share_buffer else sp, 2, n_kv, h, device=device, dtype=dtype) + past_k, past_v = past_kv.unbind(dim=2) + + input_dict_bsnh = { + "query": q.detach().cpu().numpy(), + "key": k.detach().cpu().numpy(), + "value": v.detach().cpu().numpy(), + "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", device_id), + "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", device_id), + } + input_dict_bnsh = { + "query": q.detach().cpu().numpy(), + "key": k.detach().cpu().numpy(), + "value": v.detach().cpu().numpy(), + "past_key": OrtValue.ortvalue_from_numpy( + past_k.transpose(1, 2).detach().cpu().numpy(), "cuda", 0 + ), + "past_value": OrtValue.ortvalue_from_numpy( + past_v.transpose(1, 2).detach().cpu().numpy(), "cuda", 0 + ), + } + if share_buffer: + input_dict_bsnh["past_sequence_length"] = ( + torch.tensor([sp], device="cuda", dtype=torch.int32).detach().cpu().numpy() + ) + input_dict_bnsh["past_sequence_length"] = ( + torch.tensor([sp], device="cuda", dtype=torch.int32).detach().cpu().numpy() + ) + + io_binding_bsnh = bsnh_session.io_binding() + io_binding_bnsh = bnsh_session.io_binding() + bind_io(io_binding_bsnh, input_dict_bsnh, device_id, share_buffer) + bind_io(io_binding_bnsh, input_dict_bnsh, device_id, share_buffer) + average_gqa_bsnh_latency = benchmark_op(bsnh_session, io_binding_bsnh, repeats) + average_gqa_bnsh_latency = benchmark_op(bnsh_session, io_binding_bnsh, repeats) + + del bsnh_session + del bnsh_session + + # compute TFLOPS per second + bsnh_speed = tflops_per_second(flops(b, s_q, s_kv, h, n_q), average_gqa_bsnh_latency) + print(f"bsnh\t{b}\t{s_kv}\t{n_q}\t{h}\t{average_gqa_bsnh_latency * 1000:.2f}\t{bsnh_speed:.2f}") + bnsh_speed = tflops_per_second(flops(b, s_q, s_kv, h, n_q), average_gqa_bnsh_latency) + print(f"bnsh\t{b}\t{s_kv}\t{n_q}\t{h}\t{average_gqa_bnsh_latency * 1000:.2f}\t{bnsh_speed:.2f}") + print("---------") + if average_gqa_bsnh_latency > 10 * average_gqa_bnsh_latency: + continue + num_trials += 1 + mean_bsnh_lat += average_gqa_bsnh_latency + mean_bnsh_lat += average_gqa_bnsh_latency + mean_bsnh_lat /= num_trials + mean_bnsh_lat /= num_trials + print(f"average bsnh latency:\t{mean_bsnh_lat}") + print(f"average bnsh latency:\t{mean_bnsh_lat}") + + +if __name__ == "__main__": + run_tflops_test() diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py new file mode 100644 index 0000000000..1e75268ea6 --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -0,0 +1,343 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" +Benchmark performance of MultiHeadAttention with Nvidia GPU of Compute Capability 8.0, 8.6 or 8.9 in Linux: +sh benchmark_mha.sh +""" + +import math +import os +import statistics +import time + +import torch +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession +from onnxruntime.transformers.io_binding_helper import CudaSession + + +class InputFormats: + Q_K_V_BSNH = 0 + QKV_BSN3H = 1 + Q_KV_BSNH_BSN2H = 2 + + @staticmethod + def input_format_str(format: int) -> str: + return "QKV" if format == 1 else "Q,KV" if format == 2 else "Q,K,V" + + +class Config: + batch_size: int = 0 + sequence_length: int = 0 + kv_sequence_length: int = 0 + num_heads: int = 0 + head_size: int = 0 + causal: bool = False + input_format: int = InputFormats.Q_K_V_BSNH + + def __init__(self, b, s, s2, n, h, causal, input_format): + self.batch_size = b + self.sequence_length = s + self.kv_sequence_length = s2 + self.num_heads = n + self.head_size = h + self.causal = causal + self.input_format = input_format + + +def create_multihead_attention_graph(config: Config): + query = helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.num_heads * config.head_size, + ], + ) + + key = helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.num_heads * config.head_size, + ], + ) + + value = helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.num_heads * config.head_size, + ], + ) + + packed_qkv = helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.num_heads, + 3, + config.head_size, + ], + ) + + packed_kv = helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.num_heads, + 2, + config.head_size, + ], + ) + + if config.input_format == InputFormats.QKV_BSN3H: + input_names = ["query"] + inputs = [packed_qkv] + elif config.input_format == InputFormats.Q_KV_BSNH_BSN2H: + input_names = ["query", "key"] + inputs = [query, packed_kv] + else: # input_format==InputFormats.Q_K_V_BSNH + input_names = ["query", "key", "value"] + inputs = [query, key, value] + + nodes = [ + helper.make_node( + "MultiHeadAttention", + input_names, + ["output"], + "MultiHeadAttention_0", + num_heads=config.num_heads, + domain="com.microsoft", + ), + ] + + outputs = [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + ), + ] + + graph = helper.make_graph( + nodes, + "MultiHeadAttention_Graph", + inputs, + outputs, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def input_output_shapes(config: Config): + if config.input_format == InputFormats.QKV_BSN3H: + return { + "query": (config.batch_size, config.sequence_length, config.num_heads, 3, config.head_size), + "output": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), + } + + if config.input_format == InputFormats.Q_KV_BSNH_BSN2H: + return { + "query": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), + "key": (config.batch_size, config.kv_sequence_length, config.num_heads, 2, config.head_size), + "output": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), + } + + return { + "query": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), + "key": (config.batch_size, config.kv_sequence_length, config.num_heads * config.head_size), + "value": (config.batch_size, config.kv_sequence_length, config.num_heads * config.head_size), + "output": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), + } + + +def create_session( + device_id: int, config: Config, provider: str = "CUDAExecutionProvider", enable_cuda_graph: bool = False +) -> CudaSession: + onnx_model_str = create_multihead_attention_graph(config) + provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) + ort_session = InferenceSession(onnx_model_str, providers=[(provider, provider_options), "CPUExecutionProvider"]) + device = torch.device("cuda", device_id) + cuda_session = CudaSession(ort_session, device, enable_cuda_graph) + shape_dict = input_output_shapes(config) + cuda_session.allocate_buffers(shape_dict) + return cuda_session + + +def measure_latency(cuda_session: CudaSession, input_dict): + start = time.time() + _ = cuda_session.infer(input_dict) + end = time.time() + return end - start + + +def flops(batch, sequence_length, head_size, num_heads, causal): + return 4 * batch * sequence_length**2 * num_heads * head_size // (2 if causal else 1) + + +def tflops_per_second(flop, time): + return (flop / time / 10**12) if not math.isnan(time) else 0.0 + + +def get_sm8x_kernel_name(config: Config) -> str: + # This classification is for Nvidia GPU of Compute Capability 8.* like A100. + # Note that some kernel might not exist in older or newer GPUs. + if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": + if config.input_format == InputFormats.QKV_BSN3H: + min_seq_len = os.getenv("ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV") + min_length = int(min_seq_len) if min_seq_len is not None else 513 + if config.sequence_length >= min_length: + return "Flash" + else: + return "Flash" + + if (os.getenv("ORT_DISABLE_FUSED_CROSS_ATTENTION") != "1" and config.kv_sequence_length <= 128) or ( + os.getenv("ORT_DISABLE_FUSED_ATTENTION") != "1" + and (config.sequence_length <= 384 or os.getenv("ORT_DISABLE_TRT_FLASH_ATTENTION") != "1") + ): + return "TRT" + + if os.getenv("ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION") != "1": + return "MemEff" + + return "Unfused" + + +def run_tflops_test(dtype=torch.float16, enable_cuda_graph: bool = False, repeats: int = 100): + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + + # (batch_size, sequence_length, num_heads, head_size) + configs = [ + (32, 512, 64, 32), + (32, 512, 128, 16), + (16, 1024, 64, 32), + (16, 1024, 128, 16), + (8, 2048, 64, 32), + (8, 2048, 128, 16), + (4, 4096, 64, 32), + (4, 4096, 128, 16), + (2, 8192, 64, 32), + (2, 8192, 128, 16), + (1, 16384, 64, 32), + (1, 16384, 128, 16), + # stable diffusion + (1, 4096, 8, 40), + (1, 4096, 8, 80), + (1, 4096, 8, 160), + (4, 4096, 8, 40), + (4, 4096, 8, 80), + (4, 4096, 8, 160), + (1, 16384, 8, 40), + (1, 16384, 8, 80), + (1, 16384, 8, 160), + # bert-base + (128, 128, 12, 64), + (64, 128, 12, 64), + (128, 384, 12, 64), + (64, 384, 12, 64), + (128, 512, 12, 64), + (64, 512, 12, 64), + # TNLGv4 + (4, 2048, 32, 128), + (4, 4096, 32, 128), + (8, 2048, 32, 128), + (8, 4096, 32, 128), + ] + + print(f"enable_cuda_graph={enable_cuda_graph}") + + # List of environment variables to enable/disable attention kernels + print("Environment Variables:") + env_names = [ + "ORT_DISABLE_FLASH_ATTENTION", + "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV", + "ORT_DISABLE_FUSED_ATTENTION", + "ORT_DISABLE_TRT_FLASH_ATTENTION", + "ORT_ENABLE_FUSED_CAUSAL_ATTENTION", + "ORT_DISABLE_FUSED_CROSS_ATTENTION", + "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", + ] + for name in env_names: + value = os.getenv(name) + if value is not None: + print(f"{name}={value}") + print() + + print("format\tcausal\tbatch\tseqlen\theads\th_dim\tms\tTFLOPS\tkernel") + causal = False + for input_format in [InputFormats.Q_K_V_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H]: + for batch_size, sequence_length, num_heads, head_size in configs: + config = Config(batch_size, sequence_length, sequence_length, num_heads, head_size, causal, input_format) + + session = create_session(device_id, config, enable_cuda_graph=enable_cuda_graph) + + qkv = torch.randn(batch_size, sequence_length, 3, num_heads, head_size, device=device, dtype=dtype) + q, k, v = qkv.unbind(dim=2) + + if input_format == InputFormats.QKV_BSN3H: + if config.sequence_length != config.kv_sequence_length: + continue + q = torch.reshape(q, (-1, config.num_heads, config.head_size)) + k = torch.reshape(k, (-1, config.num_heads, config.head_size)) + v = torch.reshape(v, (-1, config.num_heads, config.head_size)) + packed_qkv = torch.dstack((q, k, v)).reshape( + config.batch_size, config.sequence_length, config.num_heads, 3, config.head_size + ) + input_dict = {"query": packed_qkv.contiguous()} + elif input_format == InputFormats.Q_KV_BSNH_BSN2H: + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + k = torch.reshape(k, (-1, config.num_heads, config.head_size)) + v = torch.reshape(v, (-1, config.num_heads, config.head_size)) + packed_kv = torch.dstack((k, v)).reshape( + config.batch_size, config.sequence_length, config.num_heads, 2, config.head_size + ) + input_dict = {"query": q.contiguous(), "key": packed_kv.contiguous()} + else: # input_format == InputFormats.Q_K_V_BSNH + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) + v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) + input_dict = { + "query": q.contiguous(), + "key": k.contiguous(), + "value": v.contiguous(), + } + + # warm up session + _ = measure_latency(session, input_dict) + + latency_list = [] + for _ in range(repeats): + latency = measure_latency(session, input_dict) + latency_list.append(latency) + average_latency = statistics.mean(latency_list) + + del session + + # compute TFLOPS per second + speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency) + + kernel = get_sm8x_kernel_name(config) + format = InputFormats.input_format_str(input_format) + print( + f"{format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t{average_latency * 1000:.2f}\t{speed:.2f}\t{kernel}" + ) + + +if __name__ == "__main__": + run_tflops_test(enable_cuda_graph=False) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh new file mode 100644 index 0000000000..7b21cf1cc1 --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -0,0 +1,14 @@ +echo "flash attention v2" +ORT_DISABLE_FLASH_ATTENTION=0 ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV=0 python benchmark_mha.py | tee result.txt + +echo "===" +echo "TensorRT attention kernels - cross attention (when kv_seq_len <= 128) or fused attention (when seq_len <= 384) or flash attention (seq_len > 384)" +ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py | tee -a result.txt + +echo "===" +echo "Memory Efficient attention" +ORT_DISABLE_FLASH_ATTENTION=1 ORT_DISABLE_TRT_FLASH_ATTENTION=1 ORT_DISABLE_FUSED_ATTENTION=1 ORT_DISABLE_FUSED_CROSS_ATTENTION=1 python benchmark_mha.py | tee -a result.txt + +echo "===" +echo "Unfused Attention (some configurations might fail)" +ORT_DISABLE_FLASH_ATTENTION=1 ORT_DISABLE_TRT_FLASH_ATTENTION=1 ORT_DISABLE_FUSED_ATTENTION=1 ORT_DISABLE_FUSED_CROSS_ATTENTION=1 ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION=1 python benchmark_mha.py | tee -a result.txt diff --git a/onnxruntime/test/python/transformers/bert_padding.py b/onnxruntime/test/python/transformers/bert_padding.py new file mode 100644 index 0000000000..a4ef765264 --- /dev/null +++ b/onnxruntime/test/python/transformers/bert_padding.py @@ -0,0 +1,131 @@ +# From https://github.com/Dao-AILab/flash-attention/blob/2286d7cea7ca8264165c16b2442b6436c43140de/flash_attn/bert_padding.py + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape( + -1, *other_shape + ) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... -> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +class IndexFirstAxisResidual(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, _ = input.shape[0], input.shape[1:] + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + output = input[indices] + # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last + # memory format to channel_first. In other words, input might not be contiguous. + # If we don't detach, Pytorch complains about output being a view and is being modified inplace + return output, input.detach() + + @staticmethod + def backward(ctx, grad_output, grad_residual): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + assert grad_residual.shape[1:] == other_shape + grad_input = grad_residual + # grad_input[indices] += grad_output + indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) + indices = indices.expand_as(grad_output) + grad_input.scatter_add_(0, indices, grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis_residual = IndexFirstAxisResidual.apply + + +def unpad_input(hidden_states, attention_mask): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz) + Return: + hidden_states: (batch, seqlen, ...) + """ + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... -> b s ...", b=batch) diff --git a/onnxruntime/test/python/transformers/gpt2_model_generator.py b/onnxruntime/test/python/transformers/gpt2_model_generator.py index 6d4d6ea920..4a1b48d4d1 100644 --- a/onnxruntime/test/python/transformers/gpt2_model_generator.py +++ b/onnxruntime/test/python/transformers/gpt2_model_generator.py @@ -555,6 +555,8 @@ def create_gpt2_embedlayer( num_heads=4, epsilon=0.1, one_attention_node=False, + has_skip_layer_norm=True, + output_embedding_sum=False, ): # Construct input and output nodes inputs = [ @@ -564,21 +566,47 @@ def create_gpt2_embedlayer( helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]) ] + if output_embedding_sum: + outputs.append( + helper.make_tensor_value_info( + "embedding_sum", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size] + ) + ) + # Construct graph nodes embed_layernorm_nodes = [ helper.make_node("Gather", ["word_embeddings_weight", "ids"], ["gather_0_out"], "gather_word_embeddings"), helper.make_node("Gather", ["pos_embeddings_weight", "ids"], ["gather_1_out"], "gather_position_embeddings"), helper.make_node("Add", ["gather_0_out", "gather_1_out"], ["add_0_out"], "add_before_layernorm"), helper.make_node("Gather", ["token_embeddings_weight", "ids"], ["gather_2_out"], "gather_token_embeddings"), - helper.make_node( - "SkipLayerNormalization", - ["add_0_out", "gather_2_out", "layernorm_weight", "layernorm_bias"], - ["skip_layernorm_out"], - "skip_layernorm", - domain="com.microsoft", - epsilon=epsilon, - ), ] + + if has_skip_layer_norm: + embed_layernorm_nodes.append( + helper.make_node( + "SkipLayerNormalization", + ["add_0_out", "gather_2_out", "layernorm_weight", "layernorm_bias"], + ["skip_layernorm_out"] if not output_embedding_sum else ["skip_layernorm_out", "", "", "embedding_sum"], + "skip_layernorm", + domain="com.microsoft", + epsilon=epsilon, + ) + ) + else: + embed_layernorm_nodes.append( + helper.make_node("Add", ["add_0_out", "gather_2_out"], ["embedding_sum"], "embedding_sum") + ) + + embed_layernorm_nodes.append( + helper.make_node( + "LayerNormalization", + ["embedding_sum", "layernorm_weight", "layernorm_bias"], + ["skip_layernorm_out"], + "layernorm", + epsilon=epsilon, + ) + ) + attention_nodes = ( [ helper.make_node("MatMul", ["skip_layernorm_out", "q_weight"], ["q_out"], "q_attn"), @@ -708,6 +736,7 @@ def create_gpt2_fused_embedlayer( num_heads=4, epsilon=0.1, one_attention_node=False, + output_embedding_sum=False, ): # Construct input and output nodes inputs = [ @@ -716,6 +745,12 @@ def create_gpt2_fused_embedlayer( outputs = [ helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size]) ] + if output_embedding_sum: + outputs.append( + helper.make_tensor_value_info( + "embedding_sum", TensorProto.FLOAT, ["batch_size", "sequence_length", hidden_size] + ) + ) # Construct graph nodes embed_layernorm_nodes = [ @@ -732,7 +767,9 @@ def create_gpt2_fused_embedlayer( "", "ids", ], - ["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index"], + ["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index", "embedding_sum"] + if output_embedding_sum + else ["EmbedLayerNormalization_0_output", "EmbedLayerNormalization_0_dummy_mask_index"], "EmbedLayerNormalization_0", domain="com.microsoft", epsilon=epsilon, @@ -876,3 +913,9 @@ def create_gpt2_fused_embedlayer( model = create_gpt2_fused_embedlayer(one_attention_node=True) onnx.save(model, "./test_data/models/gpt2_embedlayer_one_attn_exp.onnx") + + model = create_gpt2_embedlayer(one_attention_node=True, output_embedding_sum=True) + onnx.save(model, "gpt2_embedlayer_one_attn_output_sum.onnx") + + model = create_gpt2_fused_embedlayer(one_attention_node=True, output_embedding_sum=True) + onnx.save(model, "./test_data/models/gpt2_embedlayer_one_attn_output_sum_exp.onnx") diff --git a/onnxruntime/test/python/transformers/test_attention_fusion.py b/onnxruntime/test/python/transformers/test_attention_fusion.py index 2edc2ec06d..76d1dcf013 100644 --- a/onnxruntime/test/python/transformers/test_attention_fusion.py +++ b/onnxruntime/test/python/transformers/test_attention_fusion.py @@ -31,7 +31,18 @@ def verify_fusion(self, optimized_model, expected_model_filename): expected_model = OnnxModel(onnx.load(expected_model_path)) expected_model.topological_sort(is_deterministic=True) - self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph)) + nodes = optimized_model.model.graph.node + self.assertEqual(len(nodes), len(expected_model.model.graph.node)) + + for i in range(len(nodes)): + self.assertEqual(nodes[i], expected_model.model.graph.node[i]) + + for expected_initializer in expected_model.model.graph.initializer: + self.assertTrue( + OnnxModel.has_same_value( + optimized_model.get_initializer(expected_initializer.name), expected_initializer + ) + ) def test_multi_head_attention_fusion(self): model = create_bert_attention() diff --git a/onnxruntime/test/python/transformers/test_data/models/attention_mha.onnx b/onnxruntime/test/python/transformers/test_data/models/attention_mha.onnx index 76d808538e..216f5444d8 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/attention_mha.onnx and b/onnxruntime/test/python/transformers/test_data/models/attention_mha.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/attention_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/attention_opt.onnx index ececb8701a..cc712c61af 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/attention_opt.onnx and b/onnxruntime/test/python/transformers/test_data/models/attention_opt.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/attention_with_varied_qkv_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/attention_with_varied_qkv_opt.onnx index da048bbe5c..25dc71ff51 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/attention_with_varied_qkv_opt.onnx and b/onnxruntime/test/python/transformers/test_data/models/attention_with_varied_qkv_opt.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/bert_3d_attention_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/bert_3d_attention_opt.onnx index fe5384bd4e..53a2809038 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/bert_3d_attention_opt.onnx and b/onnxruntime/test/python/transformers/test_data/models/bert_3d_attention_opt.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_no_skiplayernorm.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_no_skiplayernorm.onnx index 177c29f607..b4ed7169df 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_no_skiplayernorm.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_no_skiplayernorm.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_skiplayernorm.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_skiplayernorm.onnx index 036d6c1601..0ef3b08319 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_skiplayernorm.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_skiplayernorm.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_no_skiplayernorm.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_no_skiplayernorm.onnx index 7f1174d966..62b51b9dd2 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_no_skiplayernorm.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_no_skiplayernorm.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_skiplayernorm.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_skiplayernorm.onnx index ee11024900..0ef3b08319 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_skiplayernorm.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_skiplayernorm.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_embedlayer_one_attn_output_sum_exp.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_embedlayer_one_attn_output_sum_exp.onnx new file mode 100644 index 0000000000..853f3f5cf7 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/gpt2_embedlayer_one_attn_output_sum_exp.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_no_skiplayernorm.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_no_skiplayernorm.onnx index debd5244ab..8f545aa8b3 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_no_skiplayernorm.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_no_skiplayernorm.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_skiplayernorm.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_skiplayernorm.onnx index 856d76947a..8d0f1697b5 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_skiplayernorm.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_skiplayernorm.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/pruned_attention_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/pruned_attention_opt.onnx index 51bf9f08ff..337ede0e4e 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/pruned_attention_opt.onnx and b/onnxruntime/test/python/transformers/test_data/models/pruned_attention_opt.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx index 2fc6a8959d..25265839c8 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_fused.onnx index 0c5035f7dc..5f21da7e59 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_split_bias_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_split_bias_fused.onnx index 8759d958d3..1da242e19e 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_split_bias_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_split_bias_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx index 7b3368d824..e7a201658b 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx index 3c7b613f42..bc72c9b350 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_fused.onnx index 1119e4c51a..969f20b286 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx index 6a4ee4761a..ca7f33a3f1 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx index 190b70741f..15a178863b 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_embedlayer_fusion.py b/onnxruntime/test/python/transformers/test_embedlayer_fusion.py index 732833e5da..ccd367fdbb 100644 --- a/onnxruntime/test/python/transformers/test_embedlayer_fusion.py +++ b/onnxruntime/test/python/transformers/test_embedlayer_fusion.py @@ -74,6 +74,38 @@ def test_embedlayer_fusion_one_attn_node(self): os.remove(original_model_path) os.remove(optimized_model_path) + def test_embedlayer_fusion_with_embedding_sum_output(self): + model = create_gpt2_embedlayer(one_attention_node=True, output_embedding_sum=True) + path = "." + original_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum.onnx") + optimized_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum_opt.onnx") + expected_model_filename = "gpt2_embedlayer_one_attn_output_sum_exp.onnx" + + onnx.save(model, original_model_path) + optimized_model = optimize_model(original_model_path, model_type="gpt2") + optimized_model.save_model_to_file(optimized_model_path, use_external_data_format=True) + + self.verify_fusion(optimized_model, expected_model_filename) + self.verify_parity(optimized_model_path, expected_model_filename) + os.remove(original_model_path) + os.remove(optimized_model_path) + + def test_embedlayer_fusion_with_embedding_sum_output_no_sln(self): + model = create_gpt2_embedlayer(one_attention_node=True, has_skip_layer_norm=False, output_embedding_sum=True) + path = "." + original_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum_no_sln.onnx") + optimized_model_path = os.path.join(path, "gpt2_embedlayer_one_attn_output_sum_no_sln_opt.onnx") + expected_model_filename = "gpt2_embedlayer_one_attn_output_sum_exp.onnx" + + onnx.save(model, original_model_path) + optimized_model = optimize_model(original_model_path, model_type="gpt2") + optimized_model.save_model_to_file(optimized_model_path, use_external_data_format=True) + + self.verify_fusion(optimized_model, expected_model_filename) + self.verify_parity(optimized_model_path, expected_model_filename) + os.remove(original_model_path) + os.remove(optimized_model_path) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py new file mode 100644 index 0000000000..99f62ffdb9 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -0,0 +1,1748 @@ +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import math +import os +import platform +import random +import unittest + +import numpy +import torch +from bert_padding import pad_input, unpad_input +from einops import rearrange, repeat +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession, OrtValue, SessionOptions + +torch.manual_seed(0) + +pipeline_mode = True # Reduces number of tests so pipeline doesn't time out + + +class Formats: + BSNH = 0 + BNSH = 1 + + +class Config: + batch_size = 0 + sequence_length = 0 + kv_sequence_length = 0 + past_sequence_length = 0 + num_heads = 0 + kv_num_heads = 0 + head_size = 0 + + def __init__(self, b, s, s2, sp, n, n2, h): + self.batch_size = b + self.sequence_length = s + self.kv_sequence_length = s2 + self.past_sequence_length = sp + self.num_heads = n + self.kv_num_heads = n2 + self.head_size = h + + +class PromptConfig: + batch_size = 0 + q_sequence_length = 0 + kv_sequence_length = 0 + buffer_sequence_length = 0 + num_heads = 0 + kv_num_heads = 0 + head_size = 0 + + def __init__(self, b, sq, skv, sb, n, n2, h): + self.batch_size = b + self.q_sequence_length = sq + self.kv_sequence_length = skv + self.buffer_sequence_length = sb + self.num_heads = n + self.kv_num_heads = n2 + self.head_size = h + + +def create_packed_multihead_attention_graph(config): + nodes = [ + helper.make_node( + "PackedMultiHeadAttention", + [ + "query", + "", + "", + "", + "token_offset", + "cumulative_sequence_length", + ], + ["output"], + "PackedMultiHeadAttention_0", + num_heads=config.num_heads, + domain="com.microsoft", + ), + ] + + graph = helper.make_graph( + nodes, + "PackedMultiHeadAttention_Graph", + [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + -1, + config.num_heads, + 3, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "token_offset", TensorProto.INT32, [config.batch_size, config.sequence_length] + ), + helper.make_tensor_value_info("cumulative_sequence_length", TensorProto.INT32, [config.batch_size + 1]), + ], + [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [-1, config.num_heads * config.head_size], + ), + ], + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_multihead_attention_graph(config): + nodes = [ + helper.make_node( + "MultiHeadAttention", + [ + "query", + "key", + "value", + ], + ["output"], + "MultiHeadAttention_0", + num_heads=config.num_heads, + domain="com.microsoft", + ), + ] + + graph = helper.make_graph( + nodes, + "MultiHeadAttention_Graph", + [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.num_heads * config.head_size, + ], + ), + ], + [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + ), + ], + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSNH, share_buffer=True): + past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 + present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length + nodes = [ + helper.make_node( + "GroupQueryAttention", + [ + "query", + "key", + "value", + "past_key" if share_buffer else "", + "past_value" if share_buffer else "", + "seqlens_k", + "total_sequence_length", + ], + ["output", "present_key", "present_value"], + "GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, + # kv_share_buffer=1 if share_buffer else 0, + domain="com.microsoft", + ), + ] + + graph_input = [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.q_sequence_length, + config.num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "seqlens_k", + TensorProto.INT32, + [config.batch_size], + ), + helper.make_tensor_value_info( + "total_sequence_length", + TensorProto.INT32, + [1], + ), + ] + if share_buffer: + graph_input += [ + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "past_value", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + ] + + graph_output = [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + ] + + graph = helper.make_graph( + nodes, + "GroupQueryAttention_Graph", + graph_input, + graph_output, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_group_query_attention_graph_past(config, past_kv_format=Formats.BSNH, share_buffer=True): + past_kv_seqlen = config.kv_sequence_length + present_kv_seqlen = ( + config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length + ) + nodes = [ + helper.make_node( + "GroupQueryAttention", + [ + "query", + "key", + "value", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + ], + ["output", "present_key", "present_value"], + "GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, + # kv_share_buffer=1 if share_buffer else 0, + domain="com.microsoft", + ), + ] + + graph_input = [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "past_value", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + ] + graph_input += [ + helper.make_tensor_value_info( + "seqlens_k", + TensorProto.INT32, + [config.batch_size], + ), + helper.make_tensor_value_info( + "total_sequence_length", + TensorProto.INT32, + [1], + ), + ] + + graph_output = [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + ] + + graph = helper.make_graph( + nodes, + "GroupQueryAttention_Graph", + graph_input, + graph_output, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device) + else: + lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + return padding_mask + + +def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) + + def output_pad_fn(output_unpad): + return pad_input(output_unpad, indices_q, batch_size, seqlen_q) + + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + ) + max_seqlen_q = seqlen_q + + def output_pad_fn(output_unpad): + return rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) + v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + ) + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + + def dqkv_pad_fn(dqkv_unpad): + return pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + + else: + + def dqkv_pad_fn(dqkv_unpad): + return rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + + def dkv_pad_fn(dkv_unpad): + return pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + + else: + + def dkv_pad_fn(dkv_unpad): + return rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + + def dk_pad_fn(dk_unpad): + return pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + + else: + + def dk_pad_fn(dk_unpad): + return rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def create_inputs(config: Config, kv_packed=False, qkv_packed=True): + qkv = torch.randn( + config.batch_size, + config.sequence_length, + 3, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + key_padding_mask = generate_random_padding_mask( + config.sequence_length, config.batch_size, device="cuda", mode="random" + ) + qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( + *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, kv_packed, qkv_packed + ) + return qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn, key_padding_mask + + +def generate_token_offset(cu_seqlens, max_seqlen): + token_offset = [] + token_padset = [] # These are the indices that contain padding tokens + for i in range(1, len(cu_seqlens)): + start = i - 1 + pre_seqlen = cu_seqlens[i - 1] + seqlen = cu_seqlens[i] + token_offset += range(start * max_seqlen, (start * max_seqlen) + (seqlen - pre_seqlen)) + token_padset += range((start * max_seqlen) + (seqlen - pre_seqlen), i * max_seqlen) + return numpy.asarray(token_offset + token_padset, dtype=numpy.int32) + + +def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False): + onnx_model_str = create_packed_multihead_attention_graph(config) + qkv_unpad = torch.swapdims(qkv_unpad, 1, 2) + ort_inputs = { + "query": qkv_unpad.detach().cpu().numpy(), + "token_offset": token_offset, + "cumulative_sequence_length": cu_seqlens.cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + ort_output = ort_session.run(None, ort_inputs) + output = torch.tensor(ort_output) + return output + + +def mha_func(q, k, v, config): + onnx_model_str = create_multihead_attention_graph(config) + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) + v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) + ort_inputs = { + "query": q.detach().cpu().numpy(), + "key": k.detach().cpu().numpy(), + "value": v.detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + ort_output = ort_session.run(None, ort_inputs) + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output + + +def gqa_prompt_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True): + onnx_model_str = create_group_query_attention_graph_prompt(config, past_kv_format, share_buffer) + q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) + past_k = k.clone() if share_buffer else None + past_v = v.clone() if share_buffer else None + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + if share_buffer: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "key": new_k.detach().cpu().numpy(), + "value": new_v.detach().cpu().numpy(), + "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), + "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + io_binding = ort_session.io_binding() + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + io_binding.bind_input( + "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + ) + io_binding.bind_input( + "past_value", + "cuda", + 0, + numpy.float16, + ort_inputs["past_value"].shape(), + ort_inputs["past_value"].data_ptr(), + ) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) + io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + else: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "key": new_k.detach().cpu().numpy(), + "value": new_v.detach().cpu().numpy(), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + io_binding = ort_session.io_binding() + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + + +def gqa_past_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True): + onnx_model_str = create_group_query_attention_graph_past(config, past_kv_format, share_buffer) + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + past_k = k.clone() + past_v = v.clone() + new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + if share_buffer: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "key": new_k.detach().cpu().numpy(), + "value": new_v.detach().cpu().numpy(), + "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), + "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([config.kv_sequence_length], dtype=torch.int32) + .detach() + .cpu() + .numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + io_binding = ort_session.io_binding() + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + io_binding.bind_input( + "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + ) + io_binding.bind_input( + "past_value", + "cuda", + 0, + numpy.float16, + ort_inputs["past_value"].shape(), + ort_inputs["past_value"].data_ptr(), + ) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) + io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + else: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "key": new_k.detach().cpu().numpy(), + "value": new_v.detach().cpu().numpy(), + "past_key": past_k.detach().cpu().numpy(), + "past_value": past_v.detach().cpu().numpy(), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor( + [config.kv_sequence_length + config.sequence_length], dtype=torch.int32 + ) + .detach() + .cpu() + .numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + io_binding = ort_session.io_binding() + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) + io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + + +def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + return col_idx > row_idx + sk - sq + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + upcast=True, + reorder_ops=False, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if causal: + causal_mask = construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device) + scores.masked_fill_(causal_mask, float("-inf")) + attention = torch.softmax(scores, dim=-1) + if causal: # Some rows are completely masked out so we fill them with zero instead of NaN + attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +def attention_qkvpacked_ref( + qkv, key_padding_mask=None, dropout_p=0.0, dropout_mask=None, causal=False, upcast=True, reorder_ops=False +): + return attention_ref( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + key_padding_mask, + key_padding_mask, + dropout_p, + dropout_mask, + upcast=upcast, + causal=causal, + reorder_ops=reorder_ops, + ) + + +def parity_check_mha( + config, + packed, + rtol=1e-3, + atol=1e-3, +): + if packed: + qkv_unpad, cu_seqlens, _, qkv, output_pad_fn, _, key_padding_mask = create_inputs(config) + token_offset = generate_token_offset(cu_seqlens, config.sequence_length).reshape( + (config.batch_size, config.sequence_length) + ) + # ORT Flash + out_unpad = flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False) + out_unpad = torch.squeeze(out_unpad, 0) + out = torch.reshape( + output_pad_fn(out_unpad), (config.batch_size, config.sequence_length, config.num_heads, config.head_size) + ) + out = out.detach().cpu().numpy() + # Pytorch to compare + out_ref, _ = attention_qkvpacked_ref(qkv, key_padding_mask, 0.0, None, causal=False) + out_ref = out_ref.detach().cpu().numpy() + else: + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + out = mha_func(q, k, v, config) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + # Pytorch to compare + out_ref, _ = attention_ref(q, k, v, None, None, 0.0, None, causal=False) + out_ref = out_ref.detach().cpu().numpy() + + # Compare results + print( + " B:", + config.batch_size, + " S:", + config.sequence_length, + " N:", + config.num_heads, + " kvN:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + numpy.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +def parity_check_gqa_prompt( + config, + past_format=Formats.BSNH, + rtol=1e-3, + atol=1e-3, +): + q = torch.randn( + config.batch_size, + config.q_sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + + # Pytorch to compare + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) + # cache_seqlens = torch.randint( + # 0, + # config.kv_sequence_length, + # (config.batch_size,), + # dtype=torch.int32, + # device="cuda", + # ) + # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") + arange = rearrange(torch.arange(config.buffer_sequence_length, device="cuda"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + kv_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) + kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") + update_mask = arange < kv_seqlens_expanded + k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = arange < cache_seqlens_expanded + out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + out, present_k, present_v = gqa_prompt_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, True) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Make sure past-present buffer updating correctly + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + + # Compare results + print( + "KV-buffer", + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " B:", + config.batch_size, + " S:", + config.q_sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + numpy.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +def parity_check_gqa_prompt_no_buff( + config, + past_format=Formats.BSNH, + rtol=1e-3, + atol=1e-3, +): + q = torch.randn( + config.batch_size, + config.q_sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + + # Pytorch to compare + k_cache_ref = new_k.clone() + v_cache_ref = new_v.clone() + # if past_format == Formats.BNSH: + # k_cache_ref = k_cache_ref.transpose(1, 2) + # v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) + # cache_seqlens = torch.randint( + # 0, + # config.kv_sequence_length, + # (config.batch_size,), + # dtype=torch.int32, + # device="cuda", + # ) + # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + brange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + new_mask = brange < cache_seqlens_expanded + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + out, present_k, present_v = gqa_prompt_func(q, None, None, config, new_k, new_v, cache_seqlens, past_format, False) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Make sure past-present buffer updating correctly + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + + # Compare results + print( + "KV-buffer", + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " B:", + config.batch_size, + " S:", + config.q_sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + numpy.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +def parity_check_gqa_past( + config, + past_format=Formats.BSNH, + rtol=1e-3, + atol=1e-3, +): + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + + # Pytorch to compare + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) + cache_seqlens = torch.randint( + 0, + config.kv_sequence_length - config.sequence_length + 1, + (config.batch_size,), + dtype=torch.int32, + device="cuda", + ) + arange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length + ) + k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length + out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, True) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Make sure past-present buffer updating correctly + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + + # Compare results + print( + "KV-buffer", + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " B:", + config.batch_size, + " S:", + config.sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + numpy.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +def parity_check_gqa_past_no_buff( + config, + past_format=Formats.BSNH, + rtol=1e-3, + atol=1e-3, +): + torch.manual_seed(69) + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + + # Pytorch to compare + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + k_cache_ref = torch.cat((k_cache_ref, new_k), 1) + v_cache_ref = torch.cat((v_cache_ref, new_v), 1) + # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) + cache_seqlens = torch.randint( + 0, + config.kv_sequence_length, + (config.batch_size,), + dtype=torch.int32, + device="cuda", + ) + cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cuda"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length + ) + k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length + out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, False) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Make sure past-present buffer updating correctly + # assert numpy.allclose( + # present_k[:, :, :-1, :], k_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True + # ) + # assert numpy.allclose( + # present_v[:, :, :-1, :], v_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True + # ) + + # Compare results + print( + "NO buff", + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " B:", + config.batch_size, + " S:", + config.sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + numpy.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +def parity_check_gqa_past_no_buff_no_mask( + config, + past_format=Formats.BSNH, + rtol=1e-3, + atol=1e-3, +): + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.past_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.past_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.past_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.past_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + + # Pytorch to compare + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + k_cache_ref = torch.cat((k_cache_ref, new_k), 1) + v_cache_ref = torch.cat((v_cache_ref, new_v), 1) + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = None + out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, past_format, False) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Make sure past-present buffer updating correctly + if past_format == Formats.BSNH: + assert numpy.allclose( + present_k, + k_cache_ref.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + assert numpy.allclose( + present_v, + v_cache_ref.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + else: + assert numpy.allclose( + present_k, + k_cache_ref.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + assert numpy.allclose( + present_v, + v_cache_ref.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + ) + + # Compare results + print( + "Unbuffered", + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " B:", + config.batch_size, + " S:", + config.sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + numpy.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +class TestMHA(unittest.TestCase): + def test_packed_mha(self): + if not torch.cuda.is_available() or platform.system() != "Linux": + return + major, _ = torch.cuda.get_device_capability() + if major < 8: + return + print("-------- TEST PACKED MHA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048] + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for s in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s, 0, n, n, h) + parity_check_mha(config, True) + + def test_mha(self): + if not torch.cuda.is_available() or platform.system() != "Linux": + return + major, _ = torch.cuda.get_device_capability() + if major < 8: + return + print("-------- TEST MHA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = ( + [(1, 128), (113, 211), (2048, 2048)] + if pipeline_mode + else [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ] + ) + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for s, s2 in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s2, 0, n, n, h) + parity_check_mha(config, False) + + +class TestGQA(unittest.TestCase): + def test_gqa_no_past(self): + if not torch.cuda.is_available(): + return + major, minor = torch.cuda.get_device_capability() + torch.manual_seed(69) + print("-------- TEST GQA NO PAST (PROMPT CASE) ---------") + batches = [3] if pipeline_mode else [1, 3, 5] + seqs = ( + [ + (127, 127), + (35, 35), + (2000, 2000), + (200, 200), + (240, 240), + ] + if pipeline_mode + else [ + (127, 127), + (35, 35), + (2000, 2000), + (200, 200), + (240, 240), + ] + ) + num_h = [(32, 32), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + if major < 5 or (major == 5 and minor < 3): + return + print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + for b in batches: + for sq, skv in seqs: + for n, n2 in num_h: + for h in h_sizes: + for past_kv_format in [Formats.BNSH]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt(config, past_format=past_kv_format) + parity_check_gqa_prompt_no_buff(config, past_format=past_kv_format) + if major < 8 or platform.system() != "Linux": + return + print("------- FLASH ATTENTION (PROMPT CASE) --------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + for b in batches: + for sq, skv in seqs: + for n, n2 in num_h: + for h in h_sizes: + for past_kv_format in [Formats.BNSH]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt(config, past_format=past_kv_format) + parity_check_gqa_prompt_no_buff(config, past_format=past_kv_format) + + def test_gqa_past(self): + if not torch.cuda.is_available(): + return + major, minor = torch.cuda.get_device_capability() + if major < 5 or (major == 5 and minor < 3): + return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("-------- TEST GQA PAST (TOKEN GEN) ---------") + print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") + batches = [5] if pipeline_mode else [1, 3, 5] + seqs = ( + [(1, 128), (1, 1024), (1, 2048)] + if pipeline_mode + else [ + (1, 128), + (1, 339), + (1, 1024), + (1, 5000), + (1, 800), + (1, 256), + (1, 799), + (1, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 32), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for past_kv_format in [Formats.BNSH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + if major < 8 or platform.system() != "Linux": + return + print("------- FLASH ATTENTION (TOKEN GEN) -------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for past_kv_format in [Formats.BNSH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + + +if __name__ == "__main__": + unittest.main() + # test_gqa = TestGQA() + # test_gqa.test_gqa_past() diff --git a/onnxruntime/test/python/transformers/test_group_norm.py b/onnxruntime/test/python/transformers/test_group_norm.py new file mode 100644 index 0000000000..bf295a65c8 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_group_norm.py @@ -0,0 +1,541 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import statistics +from dataclasses import dataclass +from enum import Enum +from time import perf_counter +from typing import Optional, Tuple + +import numpy +import torch +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession +from onnxruntime.transformers.io_binding_helper import CudaSession + +torch.manual_seed(0) + + +class GroupNormOpType(Enum): + GROUP_NORM = 1 + SKIP_GROUP_NORM = 2 + + +@dataclass +class GroupNormConfig: + batch_size: int + height: int + width: int + channels: int + epsilon: float = 1e-5 + num_groups: int = 32 + activation: bool = False + channels_last: bool = True + fp16: bool = False + + op_type: GroupNormOpType = GroupNormOpType.GROUP_NORM + has_bias: bool = False + has_add_out: bool = False + broadcast_skip: int = 0 # 2 for (N, C), 4 for (N, 1, 1, C) + + def get_skip_symbolic_shape(self): + skip_shape = {0: ["N", "H", "W", "C"], 2: ["N", "C"], 4: ["N", 1, 1, "C"]} + return skip_shape[self.broadcast_skip] + + def get_skip_shape(self): + skip_shape = { + 0: [self.batch_size, self.height, self.width, self.channels], + 2: [self.batch_size, self.channels], + 4: [self.batch_size, 1, 1, self.channels], + } + return skip_shape[self.broadcast_skip] + + def broadcast(self, skip: torch.Tensor): + if self.broadcast_skip == 2: + return skip.reshape(self.batch_size, 1, 1, self.channels) + + return skip + + @staticmethod + def create( + b: int, + h: int, + w: int, + c: int, + fp16: bool = False, + activation: bool = False, + template: int = 0, + num_groups: int = 32, + ): + if template == 0: + return GroupNormConfig( + b, h, w, c, fp16=fp16, activation=activation, op_type=GroupNormOpType.GROUP_NORM, num_groups=num_groups + ) + + if template == 1: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=True, + has_add_out=True, + broadcast_skip=0, + num_groups=num_groups, + ) + + if template == 2: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=False, + broadcast_skip=2, + num_groups=num_groups, + ) + + if template == 3: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=True, + has_add_out=False, + broadcast_skip=4, + num_groups=num_groups, + ) + + if template == 4: # No bias + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=True, + broadcast_skip=0, + num_groups=num_groups, + ) + + if template == 5: # No bias, no add_out + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=False, + broadcast_skip=0, + num_groups=num_groups, + ) + + return None + + +def create_group_norm_graph(config: GroupNormConfig) -> bytes: + inputs = ["input", "gamma", "beta"] + outputs = ["output"] + op_type = "GroupNorm" + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + op_type = "SkipGroupNorm" + inputs = [*inputs, "skip"] + if config.has_bias: + inputs = [*inputs, "bias"] + if config.has_add_out: + outputs = [*outputs, "add_out"] + + nodes = [ + helper.make_node( + op_type, + inputs, + outputs, + op_type + "_0", + activation=int(config.activation), + channels_last=int(config.channels_last), + epsilon=config.epsilon, + groups=config.num_groups, + domain="com.microsoft", + ), + ] + + float_type = TensorProto.FLOAT16 if config.fp16 else TensorProto.FLOAT + + input_shapes = [ + helper.make_tensor_value_info("input", float_type, ["N", "H", "W", "C"]), + helper.make_tensor_value_info("gamma", TensorProto.FLOAT, ["C"]), + helper.make_tensor_value_info("beta", TensorProto.FLOAT, ["C"]), + ] + output_shapes = [ + helper.make_tensor_value_info("output", float_type, ["N", "H", "W", "C"]), + ] + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + input_shapes = [ + *input_shapes, + helper.make_tensor_value_info("skip", float_type, config.get_skip_symbolic_shape()), + ] + if config.has_bias: + input_shapes = [*input_shapes, helper.make_tensor_value_info("bias", float_type, ["C"])] + if config.has_add_out: + output_shapes = [*output_shapes, helper.make_tensor_value_info("add_out", float_type, ["N", "H", "W", "C"])] + + graph = helper.make_graph( + nodes, + "Group_Norm_Graph", + input_shapes, + output_shapes, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def group_norm_ort( + src: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + skip: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + config: GroupNormConfig, + measure_latency=False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[float]]: + onnx_model_str = create_group_norm_graph(config) + ort_session = InferenceSession(onnx_model_str, providers=["CUDAExecutionProvider"]) + + session = CudaSession(ort_session, device=torch.device("cuda:0")) + + io_shape = { + "input": [config.batch_size, config.height, config.width, config.channels], + "gamma": [config.channels], + "beta": [config.channels], + "output": [config.batch_size, config.height, config.width, config.channels], + } + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + io_shape["skip"] = config.get_skip_shape() + if config.has_bias: + io_shape["bias"] = [config.channels] + if config.has_add_out: + io_shape["add_out"] = [config.batch_size, config.height, config.width, config.channels] + + session.allocate_buffers(io_shape) + + ort_inputs = { + "input": src, + "gamma": gamma, + "beta": beta, + } + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + ort_inputs["skip"] = skip + if config.has_bias: + ort_inputs["bias"] = bias + + ort_outputs = session.infer(ort_inputs) + output = ort_outputs["output"] + + add_out = ( + ort_outputs["add_out"] if config.op_type == GroupNormOpType.SKIP_GROUP_NORM and config.has_add_out else None + ) + + if measure_latency: + latency_list = [] + for _ in range(10000): + start_time = perf_counter() + session.infer(ort_inputs) + end_time = perf_counter() + latency_list.append(end_time - start_time) + average_latency = statistics.mean(latency_list) + return output, add_out, average_latency + + return output, add_out, None + + +def group_norm_torch( + src: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + skip: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + config: GroupNormConfig, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + add_out = src + + if skip is not None: + assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM + add_out = add_out + config.broadcast(skip) + + if bias is not None: + assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM + add_out = add_out + bias.reshape(1, 1, 1, bias.shape[0]) + + x = add_out + if config.channels_last: + x = add_out.clone().permute(0, 3, 1, 2) # from NHWC to NCHW + + weight = gamma.to(x.dtype) + bias = beta.to(x.dtype) + output = torch.nn.functional.group_norm(x, config.num_groups, weight=weight, bias=bias, eps=config.epsilon) + + if config.activation: + torch.nn.functional.silu(output, inplace=True) + + if config.channels_last: + output = output.permute(0, 2, 3, 1) # from NCHW to NHWC + + return output, add_out + + +def print_tensor(name, tensor): + # Print in the format that could be directly added to unit tests in C++. + torch.set_printoptions(precision=6, sci_mode=False, linewidth=100, profile="full", threshold=1000) + print(name) + if tensor is not None: + print("shape", tensor.shape) + text = str(tensor.clone().flatten()) + print(text.replace("[", "[\n").replace("]", ",\n]").replace(",", "f,")) + else: + print(tensor) + + +def run_parity(config, measure_latency=True, verbose=False): + float_type = torch.float16 if config.fp16 else torch.float32 + + input_tensor = torch.randn( + config.batch_size, + config.height, + config.width, + config.channels, + device="cuda", + dtype=float_type, + requires_grad=False, + ) + + gamma = torch.randn( + config.channels, + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + + beta = torch.randn( + config.channels, + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + + skip = None + bias = None + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + skip = torch.randn( + *config.get_skip_shape(), + device="cuda", + dtype=float_type, + requires_grad=False, + ) + if config.has_bias: + bias = torch.randn( + config.channels, + device="cuda", + dtype=float_type, + requires_grad=False, + ) + + if verbose: + print(config) + print_tensor("input", input_tensor) + print_tensor("gamma", gamma) + print_tensor("beta", beta) + print_tensor("skip", skip) + print_tensor("bias", bias) + + out_ort, ort_add_out, latency = group_norm_ort( + input_tensor, gamma, beta, skip, bias, config, measure_latency=measure_latency + ) + + if verbose: + print_tensor("out_ort", out_ort) + print_tensor("ort_add_out", ort_add_out) + + torch_out, torch_add_out = group_norm_torch(input_tensor, gamma, beta, skip, bias, config) + + if verbose: + print_tensor("torch_out", torch_out) + print_tensor("torch_add_out", torch_add_out) + + average_diff = numpy.mean(numpy.abs(out_ort.detach().cpu().numpy() - torch_out.detach().cpu().numpy())) + + is_close = numpy.allclose( + out_ort.detach().cpu().numpy(), + torch_out.detach().cpu().numpy(), + rtol=1e-1 if config.fp16 else 1e-3, + atol=1e-1 if config.fp16 else 1e-3, + equal_nan=True, + ) + + is_add_out_close = ( + numpy.allclose( + ort_add_out.detach().cpu().numpy(), + torch_add_out.detach().cpu().numpy(), + rtol=1e-1 if config.fp16 else 1e-3, + atol=1e-1 if config.fp16 else 1e-3, + equal_nan=True, + ) + if ort_add_out is not None + else "" + ) + + # Compare results + print( + config.op_type.name, + " B:", + config.batch_size, + " H:", + config.height, + " W:", + config.width, + " C:", + config.channels, + " G:", + config.num_groups, + " activation:", + int(config.activation), + " channels_last:", + int(config.channels_last), + " fp16:", + int(config.fp16), + f" Latency(μs): {int(latency * 1e6)}" if isinstance(latency, float) else "", + " AvgDiff:", + average_diff, + " Pass:", + is_close, + is_add_out_close, + ) + + +def get_latent_height_width(): + default_size = [(512, 512), (768, 768), (1024, 1024)] + small_img_size = [(512, 768), (768, 512)] + xl_img_size = [ + (1152, 896), + (896, 1152), + (1216, 832), + (832, 1216), + (1344, 768), + (768, 1344), + (1536, 640), + (640, 1536), + ] + return [(int(h / 8), int(w / 8)) for (h, w) in default_size + small_img_size + xl_img_size] + + +def get_channels(): + return [128, 256, 512, 1024, 2048, 320, 640, 960, 1920, 2560, 384, 768, 1536, 3072, 1152, 2304] + + +def run_activation(template: int, fp16, measure_latency=False): + print("Test GroupNorm with Silu Activation for ", "fp16" if fp16 else "fp32") + for b in [2]: + for h, w in get_latent_height_width(): + for c in get_channels(): + config = GroupNormConfig.create(b, h, w, c, fp16=fp16, activation=True, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_no_activation(template: int, fp16, measure_latency=False): + print("Test GroupNorm without Activation for ", "fp16" if fp16 else "fp32") + for b in [1, 2, 4]: + for h, w in get_latent_height_width(): + for c in get_channels(): + config = GroupNormConfig.create(b, h, w, c, fp16=fp16, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_all_groups(template: int, fp16, measure_latency=False): + group_sizes = [1, 2, 4, 8, 16, 32] + print("Test GroupNorm for different group sizes:", group_sizes) + for group_size in group_sizes: + for h, w in get_latent_height_width()[:3]: + for c in get_channels()[:2]: + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=group_size, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_odd_channels(template: int, fp16, measure_latency=False): + # Test some random number of channels that can be divisible by 2 * num_groups + for h, w in get_latent_height_width(): + for c in [448, 704, 832, 1664, 2240, 2688, 2880, 3008]: + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_small_inputs(template: int, fp16): + config = GroupNormConfig.create(2, 2, 2, 16, fp16=fp16, activation=False, num_groups=4, template=template) + run_parity(config, measure_latency=False) + + config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=False, num_groups=8, template=template) + run_parity(config, measure_latency=False) + + config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=True, num_groups=8, template=template) + run_parity(config, measure_latency=False) + + +def run_performance(fp16): + # Run perf test to tune parameters for given number of channels. + for h, w in get_latent_height_width()[:3]: + for c in get_channels(): + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=0) + run_parity(config, measure_latency=True) + + +def run_all(template: int): + for fp16 in [True, False]: + run_small_inputs(template, fp16) + run_odd_channels(template, fp16) + run_all_groups(template, fp16) + run_activation(template, fp16) + run_no_activation(template, fp16) + + +def run_not_implemented(): + # Expect failure. Check whether the error message is expected. + try: + config = GroupNormConfig(1, 2, 2, 513, num_groups=3) + run_parity(config) + except RuntimeError as e: + assert "GroupNorm in CUDA does not support the input: n=1 h=2 w=2 c=513 groups=3" in str(e) + + +def main(): + run_performance(True) + + run_not_implemented() + + for template in range(6): + run_all(template) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py b/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py new file mode 100644 index 0000000000..dca250f39f --- /dev/null +++ b/onnxruntime/test/python/transformers/test_optimizer_stable_diffusion.py @@ -0,0 +1,271 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import os +import shutil +import unittest + +import numpy as np +import pytest +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from compare_bert_results import run_test + from fusion_options import FusionOptions + from optimizer import optimize_model +else: + from onnxruntime.transformers.compare_bert_results import run_test + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.optimizer import optimize_model + +if find_transformers_source(["models", "stable_diffusion"]): + from optimize_pipeline import main as optimize_stable_diffusion +else: + from onnxruntime.transformers.models.stable_diffusion.optimize_pipeline import main as optimize_stable_diffusion + + +TINY_MODELS = { + "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", + "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", +} + + +class TestStableDiffusionOptimization(unittest.TestCase): + def verify_node_count(self, onnx_model, expected_node_count, test_name): + for op_type, count in expected_node_count.items(): + if len(onnx_model.get_nodes_by_op_type(op_type)) != count: + print(f"Counters is not expected in test: {test_name}") + for op, counter in expected_node_count.items(): + print(f"{op}: {len(onnx_model.get_nodes_by_op_type(op))} expected={counter}") + + self.assertEqual(len(onnx_model.get_nodes_by_op_type(op_type)), count) + + def verify_clip_optimizer(self, clip_onnx_path, optimized_clip_onnx_path, expected_counters, float16=False): + fusion_options = FusionOptions("clip") + m = optimize_model( + clip_onnx_path, + model_type="clip", + num_heads=0, + hidden_size=0, + opt_level=0, + optimization_options=fusion_options, + use_gpu=True, + ) + self.verify_node_count(m, expected_counters, "test_clip") + + if float16: + m.convert_float_to_float16( + keep_io_types=True, + ) + print(m.get_operator_statistics()) + m.save_model_to_file(optimized_clip_onnx_path) + + threshold = 1e-2 if float16 else 3e-3 + max_abs_diff, passed = run_test( + clip_onnx_path, + optimized_clip_onnx_path, + output_dir=None, + batch_size=1, + sequence_length=77, + use_gpu=True, + test_cases=10, + seed=1, + verbose=False, + rtol=1e-1, + atol=threshold, + input_ids_name="input_ids", + segment_ids_name=None, + input_mask_name=None, + mask_type=0, + ) + + self.assertLess(max_abs_diff, threshold) + self.assertTrue(passed) + + @pytest.mark.slow + def test_clip_sd(self): + save_directory = "tiny-random-stable-diffusion" + if os.path.exists(save_directory): + shutil.rmtree(save_directory, ignore_errors=True) + + model_type = "stable-diffusion" + model_name = TINY_MODELS[model_type] + + from optimum.onnxruntime import ORTStableDiffusionPipeline + + base = ORTStableDiffusionPipeline.from_pretrained(model_name, export=True) + base.save_pretrained(save_directory) + + clip_onnx_path = os.path.join(save_directory, "text_encoder", "model.onnx") + optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder", "opt.onnx") + self.verify_clip_optimizer( + clip_onnx_path, + optimized_clip_onnx_path, + expected_counters={ + "EmbedLayerNormalization": 0, + "Attention": 5, + "SkipLayerNormalization": 10, + "LayerNormalization": 1, + "Gelu": 0, + "BiasGelu": 0, + }, + float16=True, + ) + + @pytest.mark.slow + def test_clip_sdxl(self): + save_directory = "tiny-random-stable-diffusion-xl" + if os.path.exists(save_directory): + shutil.rmtree(save_directory, ignore_errors=True) + + model_type = "stable-diffusion-xl" + model_name = TINY_MODELS[model_type] + + from optimum.onnxruntime import ORTStableDiffusionXLPipeline + + base = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True) + base.save_pretrained(save_directory) + + clip_onnx_path = os.path.join(save_directory, "text_encoder", "model.onnx") + optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder", "opt.onnx") + self.verify_clip_optimizer( + clip_onnx_path, + optimized_clip_onnx_path, + expected_counters={ + "EmbedLayerNormalization": 0, + "Attention": 5, + "SkipLayerNormalization": 10, + "LayerNormalization": 1, + "Gelu": 0, + "BiasGelu": 5, + }, + ) + + clip_onnx_path = os.path.join(save_directory, "text_encoder_2", "model.onnx") + optimized_clip_onnx_path = os.path.join(save_directory, "text_encoder_2", "opt.onnx") + self.verify_clip_optimizer( + clip_onnx_path, + optimized_clip_onnx_path, + expected_counters={ + "EmbedLayerNormalization": 0, + "Attention": 5, + "SkipLayerNormalization": 10, + "LayerNormalization": 1, + "Gelu": 0, + "BiasGelu": 5, + }, + ) + + @pytest.mark.slow + def test_optimize_sdxl_fp32(self): + save_directory = "tiny-random-stable-diffusion-xl" + if os.path.exists(save_directory): + shutil.rmtree(save_directory, ignore_errors=True) + + model_type = "stable-diffusion-xl" + model_name = TINY_MODELS[model_type] + + from optimum.onnxruntime import ORTStableDiffusionXLPipeline + + baseline = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True) + if not os.path.exists(save_directory): + baseline.save_pretrained(save_directory) + + batch_size, num_images_per_prompt, height, width = 2, 2, 64, 64 + latents = baseline.prepare_latents( + batch_size * num_images_per_prompt, + baseline.unet.config["in_channels"], + height, + width, + dtype=np.float32, + generator=np.random.RandomState(0), + ) + + optimized_directory = "tiny-random-stable-diffusion-xl-optimized" + argv = [ + "--input", + save_directory, + "--output", + optimized_directory, + "--disable_group_norm", + "--disable_bias_splitgelu", + "--overwrite", + ] + optimize_stable_diffusion(argv) + + treatment = ORTStableDiffusionXLPipeline.from_pretrained(optimized_directory, provider="CUDAExecutionProvider") + inputs = { + "prompt": ["starry night by van gogh"] * batch_size, + "num_inference_steps": 3, + "num_images_per_prompt": num_images_per_prompt, + "height": height, + "width": width, + "guidance_rescale": 0.1, + "output_type": "np", + } + + ort_outputs_1 = baseline(latents=latents, **inputs) + ort_outputs_2 = treatment(latents=latents, **inputs) + self.assertTrue(np.allclose(ort_outputs_1.images[0], ort_outputs_2.images[0], atol=1e-3)) + + @pytest.mark.slow + def test_optimize_sdxl_fp16(self): + """This tests optimized fp16 pipeline, and result is deterministic for a given seed""" + save_directory = "tiny-random-stable-diffusion-xl" + if os.path.exists(save_directory): + shutil.rmtree(save_directory, ignore_errors=True) + + model_type = "stable-diffusion-xl" + model_name = TINY_MODELS[model_type] + + from optimum.onnxruntime import ORTStableDiffusionXLPipeline + + baseline = ORTStableDiffusionXLPipeline.from_pretrained(model_name, export=True) + if not os.path.exists(save_directory): + baseline.save_pretrained(save_directory) + + optimized_directory = "tiny-random-stable-diffusion-xl-optimized-fp16" + argv = [ + "--input", + save_directory, + "--output", + optimized_directory, + "--disable_group_norm", + "--disable_bias_splitgelu", + "--float16", + "--overwrite", + ] + optimize_stable_diffusion(argv) + + fp16_pipeline = ORTStableDiffusionXLPipeline.from_pretrained( + optimized_directory, provider="CUDAExecutionProvider" + ) + batch_size, num_images_per_prompt, height, width = 1, 1, 64, 64 + inputs = { + "prompt": ["starry night by van gogh"] * batch_size, + "num_inference_steps": 3, + "num_images_per_prompt": num_images_per_prompt, + "height": height, + "width": width, + "guidance_rescale": 0.1, + "output_type": "latent", + } + + seed = 123 + np.random.seed(seed) + ort_outputs_1 = fp16_pipeline(**inputs) + + np.random.seed(seed) + ort_outputs_2 = fp16_pipeline(**inputs) + + np.random.seed(seed) + ort_outputs_3 = fp16_pipeline(**inputs) + + self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) + self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py b/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py index ad4117f997..85b30bea4f 100644 --- a/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py +++ b/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py @@ -339,7 +339,7 @@ def verify_attention( ort_outputs = onnxruntime_inference(ort_session, input_hidden_states, attention_mask, layer_past) - tolerance = 1e-03 if float16 else 1e-05 + tolerance = 1e-02 if float16 else 1e-04 is_all_close, max_diff = compare_outputs(torch_outputs, ort_outputs, atol=tolerance, verbose=True) max_diffs.append(max_diff) if is_all_close: diff --git a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py new file mode 100644 index 0000000000..b17ae5f69a --- /dev/null +++ b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py @@ -0,0 +1,450 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + + +# Notes +# 1) The test cases in this file are for the following LLaMA-2 scenarios: +# - Microsoft rotary embeddings with interleaved = True +# - Prompt generation +# - Token generation +# - Hugging Face rotary embeddings (equal to Microsoft rotary embeddings with interleaved = False) +# - Prompt generation +# - Token generation +# +# 2) Shapes of position ids in ORT and `interleaved` for LLaMA-2 scenarios: +# - Microsoft model: When shape of position ids == (1), interleaved = True +# - Hugging Face model: When shape of position ids == (batch_size, sequence_length), interleaved = False + + +import unittest +from copy import deepcopy + +import numpy as np +import torch +import torch.nn as nn +from onnx import TensorProto, helper + +import onnxruntime as ort + + +class SampleInputConfig: + def __init__( + self, + batch_size=2, + sequence_length=8, + num_heads=4, + head_size=6, + max_sequence_length=16, + ): + self.batch_size = batch_size + self.sequence_length = sequence_length + self.num_heads = num_heads + self.head_size = head_size + self.hidden_size = self.num_heads * self.head_size + self.max_sequence_length = max_sequence_length + + +# LLaMA Hugging Face model +class LlamaHFRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cpu"): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def get_cos_sin_cache(self, seq_len=None, device=torch.device("cpu"), dtype=torch.float32): # noqa: B008 + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=dtype), + ) + + def rotate_half(self, x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope_bnsh(self, x, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + x_embed = (x * cos) + (self.rotate_half(x) * sin) + return x_embed + + def apply_rope_bsnh(self, x, cos, sin, position_ids): + # Two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze() # [seq_len, dim] + sin = sin.squeeze() # [seq_len, dim] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + x_embed = (x * cos) + (self.rotate_half(x) * sin) + return x_embed + + def forward(self, x, cos, sin, pos_ids, x_format="bnsh"): + if x_format == "bnsh": + return self.apply_rope_bnsh(x, cos, sin, pos_ids) + return self.apply_rope_bsnh(x, cos, sin, pos_ids) + + +# LLaMA Microsoft model +class LlamaMSRotaryEmbedding(nn.Module): + def __init__(self, hidden_size, num_heads, max_sequence_length): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.max_sequence_length = max_sequence_length + + def get_cos_sin_cache(self, theta=10000.0, head_scale=1.0, device="cpu", dtype=torch.float32): + hidden_size = self.hidden_size + n_heads = self.num_heads + max_seq_len = self.max_sequence_length + + # Precalculate rotary matrices for the sequence + # According to "Attention Is All You Need", theta_i = 10000 ^ (2 * (i - 1)/dim), i in [1, 2, ..., dim//2] + head_dim = head_scale * hidden_size / n_heads + + pos = torch.arange(0, 2 * (head_dim // 2), step=2, device=device, dtype=dtype) + freqs = 1.0 / (theta ** (pos / head_dim)) + + idx = torch.arange(max_seq_len, device=freqs.device) + freqs = torch.outer(idx, freqs) + + cos = torch.reshape(torch.cos(freqs), [1, max_seq_len, 1, -1]) + sin = torch.reshape(torch.sin(freqs), [1, max_seq_len, 1, -1]) + dtype = torch.get_default_dtype() + + return cos.to(dtype), sin.to(dtype) + + def rotate_tensor( + self, + x: torch.Tensor, # BxSxNxH + cos: torch.Tensor, # 1xSx1x(H/2) + sin: torch.Tensor, # 1xSx1x(H/2) + pos: int, + interleaved: bool, + ): + # Dimension of x is [batch_size, seq_len, n_heads, head_dim] + rot_dim = 2 * cos.shape[3] + + # Dolly requires partial rotation + x_rot = x[:, :, :, :rot_dim] + + if interleaved: + x1 = x_rot[:, :, :, 0::2] + x2 = x_rot[:, :, :, 1::2] + else: + half = x_rot.shape[-1] // 2 + x1 = x[:, :, :, 0:half] + x2 = x[:, :, :, half : 2 * half] + + seq_len = x.shape[1] + cos_x = cos[:, pos : pos + seq_len, :, :] + sin_x = sin[:, pos : pos + seq_len, :, :] + + # cos_x: (1, S, 1, H/2) + # sin_x: (1, S, 1, H/2) + # x1: (B, S, N, H/2) + # x2: (B, S, N, H/2) + real = cos_x * x1 - sin_x * x2 + imag = sin_x * x1 + cos_x * x2 + + if interleaved: + x_rot[:, :, :, 0::2] = real + x_rot[:, :, :, 1::2] = imag + else: + x_rot = torch.cat((real, imag), dim=-1) + + return torch.cat((x_rot, x[:, :, :, rot_dim:]), dim=-1) + + def forward(self, x, cos, sin, pos, interleaved): + return self.rotate_tensor(x, cos, sin, pos, interleaved) + + +class TestLlamaRotaryEmbedding(unittest.TestCase): + def setUp(self): + self.config = SampleInputConfig() + self.llama_hf = LlamaHFRotaryEmbedding(self.config.head_size, self.config.max_sequence_length) + self.llama_ms = LlamaMSRotaryEmbedding( + self.config.hidden_size, self.config.num_heads, self.config.max_sequence_length + ) + + seed = 2 + np.random.seed(seed) + torch.manual_seed(seed) + torch.set_printoptions(sci_mode=False) + + def create_onnx_graph(self, x_shape, pos_shape, cos, sin, interleaved): + inputs = [ + helper.make_tensor_value_info( + name="input", + elem_type=TensorProto.FLOAT, + shape=list(x_shape), + ), + helper.make_tensor_value_info( + name="position_ids", + elem_type=TensorProto.INT64, + shape=list(pos_shape), + ), + ] + outputs = [ + helper.make_tensor_value_info( + name="output", + elem_type=TensorProto.FLOAT, + shape=list(x_shape), + ), + ] + + initializers = [ + helper.make_tensor( + name="cos_cache", + data_type=TensorProto.FLOAT, + dims=list(torch.squeeze(cos).shape), + vals=cos.flatten().tolist(), + ), + helper.make_tensor( + name="sin_cache", + data_type=TensorProto.FLOAT, + dims=list(torch.squeeze(sin).shape), + vals=sin.flatten().tolist(), + ), + ] + nodes = [ + helper.make_node( + op_type="RotaryEmbedding", + inputs=["input", "position_ids", "cos_cache", "sin_cache"], + outputs=["output"], + interleaved=interleaved, + name="RotaryEmbedding_0", + domain="com.microsoft", + ), + ] + + graph = helper.make_graph( + nodes=nodes, + name="RotaryEmbedding_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model.SerializeToString() + + def get_eps(self): + eps = ["CPUExecutionProvider", "CUDAExecutionProvider"] + return list(filter(lambda ep: ep in ort.get_available_providers(), eps)) + + def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh): + eps = self.get_eps() + for ep in eps: + sess = ort.InferenceSession(onnx_graph, providers=[ep]) + output_ort = sess.run(None, inputs_ort)[0] + output_ort = output_ort.reshape( + (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size) + ) + + # Compare outputs as BxSxNxH + self.assertTrue(np.allclose(expected_output_bsnh, output_ort)) + + # apply_rope(x_bnsh) == apply_rope(x_bsnh).transpose(1,2) + def test_hf_bnsh_and_hf_bsnh(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + + x_bnsh_after_rope = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH + x_bsnh_after_rope = self.llama_hf( + x_bnsh.transpose(1, 2), cos_hf.transpose(1, 2), sin_hf.transpose(1, 2), pos_hf, "bsnh" + ) # output is BxSxNxH + + self.assertTrue(torch.allclose(x_bnsh_after_rope, x_bsnh_after_rope.transpose(1, 2))) + + # HF rotary == MSFT rotary non-interleaved + def test_hf_rotary_and_msft_rotary_noninterleaved(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = deepcopy(x_bsnh) # deepcopy to avoid changes made by self.llama_ms forward pass + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = 0 + output_ms = ( + self.llama_ms(x_bsd, cos_ms, sin_ms, pos_ms, interleaved=False).detach().cpu().numpy() # output is BxSxNxH + ) + + # Compare caches as Mx(H/2) + self.assertTrue( + torch.allclose(self.llama_hf.cos_cached.squeeze()[:, : (self.config.head_size // 2)], cos_ms.squeeze()) + ) + self.assertTrue( + torch.allclose(self.llama_hf.sin_cached.squeeze()[:, : (self.config.head_size // 2)], sin_ms.squeeze()) + ) + + # Compare outputs as BxSxNxH + self.assertTrue(np.allclose(output_hf.transpose(1, 2).detach().cpu().numpy(), output_ms)) + + # Prompt step, interleaved = true, pos ids shape = (1) + def test_msft_prompt_rotary_interleaved(self): + # Calculated this way to match the data in rotary_embedding_op_test.cc + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = deepcopy(x_bsnh) # deepcopy to avoid changes made by self.llama_ms forward pass + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = 0 + output_ms = self.llama_ms(deepcopy(x_bsnh), cos_ms, sin_ms, pos_ms, interleaved=True).detach().cpu().numpy() + + x_bsd = x_bsd.reshape(self.config.batch_size, self.config.sequence_length, self.config.hidden_size) + pos_ms = torch.tensor([pos_ms]) + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=True) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare inputs/outputs as BxSxNxH + self.assertTrue(np.allclose(x_bsnh.flatten(), x_bsd.flatten())) + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_ms) + + # Token generation step, interleaved = true, pos ids shape = (1) + def test_msft_token_rotary_interleaved(self): + # Calculated this way to match the data in rotary_embedding_op_test.cc + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = deepcopy(x_bsnh) # deepcopy to avoid changes made by self.llama_ms forward pass + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = 2 + output_ms = self.llama_ms(deepcopy(x_bsnh), cos_ms, sin_ms, pos_ms, interleaved=True).detach().cpu().numpy() + + x_bsd = x_bsd.reshape(self.config.batch_size, self.config.sequence_length, self.config.hidden_size) + pos_ms = torch.tensor([pos_ms]) + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=True) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare inputs/outputs as BxSxNxH + self.assertTrue(np.allclose(x_bsnh.flatten(), x_bsd.flatten())) + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_ms) + + # Prompt step, interleaved = false, pos ids shape = (batch_size, sequence_length) + def test_hf_prompt_rotary_batched_pos_ids(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_ids = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = x_bsnh.reshape(self.config.batch_size, self.config.sequence_length, self.config.hidden_size) + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ids.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ids.detach().cpu().numpy(), + } + + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + + # Token generation step, interleaved = false, pos ids shape = (batch_size, sequence_length) + def test_hf_token_rotary_batched_pos_ids(self): + x_bnsh = torch.randn(self.config.batch_size, self.config.num_heads, 1, self.config.head_size) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_ids = torch.stack([torch.tensor([2]) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = x_bsnh.reshape(self.config.batch_size, 1, self.config.hidden_size) + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ids.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ids.detach().cpu().numpy(), + } + + # Compare outputs as BxSxNxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + + # Bonus test: Prompt step, interleaved = false, pos ids shape = (1) + def test_hf_prompt_rotary_one_pos_id(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = x_bsnh.reshape(self.config.batch_size, self.config.sequence_length, self.config.hidden_size) + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = torch.tensor([0]) + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare outputs as BxSxNxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + + # Bonus test: Token generation step, interleaved = false, pos ids shape = (1) + def test_hf_token_rotary_one_pos_id(self): + x_bnsh = torch.randn(self.config.batch_size, self.config.num_heads, 1, self.config.head_size) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_ids = torch.stack([torch.tensor([2]) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = x_bsnh.reshape(self.config.batch_size, 1, self.config.hidden_size) + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = torch.tensor([2]) + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare outputs as BxSxNxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py b/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py new file mode 100644 index 0000000000..7bca48c290 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py @@ -0,0 +1,447 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import sys +import unittest +from typing import List + +import numpy as np +import onnx +from onnx import TensorProto, helper +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +def float_tensor(name: str, shape: List[int], random=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights) + + +class TestRotaryEmbeddingFusion(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.sequence_length = 8 + self.num_heads = 4 + self.head_size = 6 + self.hidden_size = self.num_heads * self.head_size + + self.past_sequence_length = 2 + self.max_sequence_length = 12 + + def verify_fusion(self, expected_model_path, original_model_path): + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + options = FusionOptions("gpt2") + optimized_model = optimize_model(original_model_path, optimization_options=options, opt_level=0) + optimized_model.topological_sort(is_deterministic=True) + + self.assertTrue(str(expected_model.model.graph), str(optimized_model.model.graph)) + + def create_initializers(self): + initializers = [ + float_tensor("cos_cache", [self.max_sequence_length, self.head_size]), + float_tensor("sin_cache", [self.max_sequence_length, self.head_size]), + helper.make_tensor( + "pos_ids_new_shape", + TensorProto.FLOAT, + [2], + np.array([self.batch_size, self.sequence_length], dtype=np.int64), + ), + helper.make_tensor("zero", TensorProto.FLOAT, [1], np.array([0], dtype=np.int64)), + helper.make_tensor("one", TensorProto.FLOAT, [1], np.array([1], dtype=np.int64)), + helper.make_tensor("two", TensorProto.FLOAT, [1], np.array([2], dtype=np.int64)), + helper.make_tensor("three", TensorProto.FLOAT, [1], np.array([3], dtype=np.int64)), + helper.make_tensor("int_max", TensorProto.FLOAT, [1], np.array([sys.maxsize], dtype=np.int64)), + ] + return initializers + + def create_inputs_and_outputs(self, model_type: str = ""): + inputs = [ + helper.make_tensor_value_info( + "input_0", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.num_heads, self.head_size], + ), + helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]), + ] + if model_type in {"past", "merged"}: + # Input will be removed in fused model since it's not used in RotaryEmbedding. + # We create this input so that we can check the `past_seq_len` path during + # RotaryEmbedding fusion. + inputs.append( + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length, self.head_size], + ) + ) + # Dummy input to test nodes for `curr_seq_len` path + if model_type != "": + inputs.append( + helper.make_tensor_value_info( + "curr_key", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.num_heads, self.head_size], + ) + ) + outputs = [ + helper.make_tensor_value_info( + "output_0", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.sequence_length, self.head_size], + ) + ] + if model_type in {"merged"}: + # Dummy output to test that nodes for `past_seq_len` path are not removed for merged model + outputs.append(helper.make_tensor_value_info("past_seq_len_plus_zero", TensorProto.FLOAT, [1])) + return inputs, outputs + + def create_fused_model(self, interleaved: bool, initializers: List[TensorProto]): + inputs, outputs = self.create_inputs_and_outputs() + + rope_node = helper.make_node( + "RotaryEmbedding", + inputs=[inputs[0].name, inputs[1].name, initializers[0].name, initializers[1].name], + outputs=[outputs[0].name], + name="RotaryEmbedding_0", + interleaved=int(interleaved), + ) + + graph = helper.make_graph( + nodes=[rope_node], + name="RotaryEmbedding_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def create_cache_path(self, model_type: str, use_redundant_squeeze_ops: bool): + # Create position ids path + reshape_node = helper.make_node( + "Reshape", + inputs=["position_ids", "pos_ids_new_shape"], + outputs=["pos_ids_reshaped"], + name="Reshape_0", + ) + pos_ids_nodes = [reshape_node] + + # Create cos path + cos_init_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["new_seq_len", "zero"], + outputs=["cos_unsqueeze"], + name="Unsqueeze_2", + ) + cos_slice_node = helper.make_node( + "Slice", + inputs=["cos_cache", "zero", "cos_unsqueeze", "two", "one"], + outputs=["cos_sliced"], + name="Slice_2", + ) + cos_nodes = [cos_init_unsqueeze_node, cos_slice_node] + + if use_redundant_squeeze_ops: + # These two nodes are eliminated by this transformers PR: https://github.com/huggingface/transformers/pull/26162 + cos_squeeze_1_node = helper.make_node( + "Squeeze", + inputs=["cos_sliced", "zero"], + outputs=["cos_squeeze_1"], + name="Squeeze_0", + ) + cos_squeeze_2_node = helper.make_node( + "Squeeze", + inputs=["cos_squeeze_1", "zero"], + outputs=["cos_squeeze_2"], + name="Squeeze_1", + ) + cos_nodes.extend([cos_squeeze_1_node, cos_squeeze_2_node]) + + cos_gather_node = helper.make_node( + "Gather", + inputs=["cos_squeeze_2" if use_redundant_squeeze_ops else "cos_sliced", "pos_ids_reshaped"], + outputs=["cos_indexed"], + name="Gather_1", + ) + cos_end_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["cos_indexed", "one"], + outputs=["cos"], + name="Unsqueeze_3", + ) + cos_nodes.extend([cos_gather_node, cos_end_unsqueeze_node]) + + # Create sin path + sin_init_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["new_seq_len", "zero"], + outputs=["sin_unsqueeze"], + name="Unsqueeze_4", + ) + sin_slice_node = helper.make_node( + "Slice", + inputs=["sin_cache", "zero", "sin_unsqueeze", "two", "one"], + outputs=["sin_sliced"], + name="Slice_3", + ) + sin_nodes = [sin_init_unsqueeze_node, sin_slice_node] + + if use_redundant_squeeze_ops: + sin_squeeze_1_node = helper.make_node( + "Squeeze", + inputs=["sin_sliced", "zero"], + outputs=["sin_squeeze_1"], + name="Squeeze_2", + ) + sin_squeeze_2_node = helper.make_node( + "Squeeze", + inputs=["sin_squeeze_1", "zero"], + outputs=["sin_squeeze_2"], + name="Squeeze_3", + ) + sin_nodes.extend([sin_squeeze_1_node, sin_squeeze_2_node]) + + sin_gather_node = helper.make_node( + "Gather", + inputs=["sin_squeeze_2" if use_redundant_squeeze_ops else "sin_sliced", "pos_ids_reshaped"], + outputs=["sin_indexed"], + name="Gather_2", + ) + sin_end_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["sin_indexed", "one"], + outputs=["sin"], + name="Unsqueeze_5", + ) + sin_nodes.extend([sin_gather_node, sin_end_unsqueeze_node]) + + # Create beginning nodes before cos and sin paths + + # Create curr seq len path + curr_transpose_node = helper.make_node( + "Transpose", + inputs=["curr_key"], + outputs=["curr_key_transposed"], + name="Transpose_curr", + perm=[0, 2, 1, 3], + ) + curr_shape_node = helper.make_node( + "Shape", + inputs=["curr_key_transposed"], + outputs=["curr_shape"], + name="Shape_curr", + ) + curr_gather_node = helper.make_node( + "Gather", + inputs=["curr_shape", "two"], + outputs=["curr_seq_len" if model_type in {"past", "merged"} else "new_seq_len"], + name="Gather_curr", + ) + beginning_nodes = [curr_transpose_node, curr_shape_node, curr_gather_node] + + if model_type in {"past", "merged"}: + # Create past seq len path + past_shape_node = helper.make_node( + "Shape", + inputs=["past_key"], + outputs=["past_shape"], + name="Shape_past", + ) + past_gather_node = helper.make_node( + "Gather", + inputs=["past_shape", "two"], + outputs=["past_seq_len"], + name="Gather_past", + ) + add_node = helper.make_node( + "Add", + inputs=["curr_seq_len", "past_seq_len"], + outputs=["new_seq_len"], + name="Add_1", + ) + beginning_nodes.extend([past_shape_node, past_gather_node, add_node]) + + if model_type == "merged": + dummy_node = helper.make_node( + "Add", + inputs=["past_seq_len", "zero"], + outputs=["past_seq_len_plus_zero"], + name="Add_dummy_node", + ) + beginning_nodes.append(dummy_node) + + return pos_ids_nodes + cos_nodes + sin_nodes + beginning_nodes + + def create_apply_rope_path(self): + start_node = helper.make_node( + "Transpose", + inputs=["input_0"], + outputs=["x"], + name="Transpose_0", + perm=[0, 2, 1, 3], + ) + + # Calculate x_half_shape + shape_node = helper.make_node( + "Shape", + inputs=["x"], + outputs=["x_shape"], + name="Shape_0", + ) + gather_node = helper.make_node( + "Gather", + inputs=["x_shape", "three"], + outputs=["x_last_idx_shape"], + name="Gather_0", + axis=0, + ) + div_node = helper.make_node( + "Div", + inputs=["x_last_idx_shape", "two"], + outputs=["x_half_shape"], + name="Div_0", + ) + unsqueeze_0_node = helper.make_node( + "Unsqueeze", + inputs=["x_half_shape", "zero"], + outputs=["x_half_shape_0"], + name="Unsqueeze_0", + ) + unsqueeze_1_node = helper.make_node( + "Unsqueeze", + inputs=["x_half_shape", "zero"], + outputs=["x_half_shape_1"], + name="Unsqueeze_1", + ) + x_half_shape_nodes = [shape_node, gather_node, div_node, unsqueeze_0_node, unsqueeze_1_node] + + # Calculate rotate_half + x1_node = helper.make_node( + "Slice", + inputs=["x", "zero", "x_half_shape_0", "three", "one"], + outputs=["x1"], + name="Slice_0", + ) + x2_node = helper.make_node( + "Slice", + inputs=["x", "x_half_shape_1", "int_max", "three", "one"], + outputs=["x2"], + name="Slice_1", + ) + neg_node = helper.make_node( + "Neg", + inputs=["x2"], + outputs=["x2_neg"], + name="Neg_0", + ) + x_rotate_half_node = helper.make_node( + "Concat", + inputs=["x2_neg", "x1"], + outputs=["x_rotate_half"], + name="Concat_0", + axis=-1, + ) + rotate_half_nodes = [x1_node, x2_node, neg_node, x_rotate_half_node] + + # Calculate x_embed + x_cos_node = helper.make_node( + "Mul", + inputs=["x", "cos"], + outputs=["x_cos"], + name="Mul_0", + ) + x_sin_node = helper.make_node( + "Mul", + inputs=["x_rotate_half", "sin"], + outputs=["x_rotate_half_sin"], + name="Mul_1", + ) + end_node = helper.make_node( + "Add", + inputs=["x_cos", "x_rotate_half_sin"], + outputs=["output_0"], + name="Add_0", + ) + x_embed_nodes = [start_node, x_cos_node, x_sin_node, end_node] + + return x_half_shape_nodes + rotate_half_nodes + x_embed_nodes + + def create_test_model(self, model_type: str, use_redundant_squeeze_ops: bool, initializers: List[TensorProto]): + apply_rope_nodes = self.create_apply_rope_path() + cache_nodes = self.create_cache_path(model_type, use_redundant_squeeze_ops) + inputs, outputs = self.create_inputs_and_outputs(model_type) + + graph = helper.make_graph( + nodes=apply_rope_nodes + cache_nodes, + name="RotaryEmbedding_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="ai.onnx", version=13) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def check_models(self, interleaved: bool, model_type: str): + initializers = self.create_initializers() + + expected_model_filename = "expected_model.onnx" + expected_model = self.create_fused_model(interleaved, initializers) + onnx.save(expected_model, expected_model_filename) + + original_model_filename = "original_model.onnx" + use_redundant_squeeze_ops = True + original_model = self.create_test_model(model_type, use_redundant_squeeze_ops, initializers) + onnx.save(original_model, original_model_filename) + + self.verify_fusion(expected_model_filename, original_model_filename) + os.remove(original_model_filename) + + use_redundant_squeeze_ops = False + original_model = self.create_test_model(model_type, use_redundant_squeeze_ops, initializers) + onnx.save(original_model, original_model_filename) + + self.verify_fusion(expected_model_filename, original_model_filename) + os.remove(expected_model_filename) + os.remove(original_model_filename) + + # Hugging Face's `decoder_model.onnx` + def test_hf_decoder_model(self): + interleaved = False # HF model does not use interleaving + model_type = "no_past" + self.check_models(interleaved, model_type) + + # Hugging Face's `decoder_with_past_model.onnx` + def test_hf_decoder_with_past_model(self): + interleaved = False # HF model does not use interleaving + model_type = "past" + self.check_models(interleaved, model_type) + + # Hugging Face's `decoder_merged.onnx` + def test_hf_decoder_merged_model(self): + interleaved = False # HF model does not use interleaving + model_type = "merged" + self.check_models(interleaved, model_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py new file mode 100644 index 0000000000..373ad86ced --- /dev/null +++ b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py @@ -0,0 +1,1210 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import sys +import unittest +from typing import List + +import numpy as np +import onnx +from onnx import NodeProto, TensorProto, helper +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +def float_tensor(name: str, shape: List[int], random=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights) + + +class TestRotaryAttentionFusion(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.sequence_length = 8 + self.num_heads = 4 + self.head_size = 6 + self.hidden_size = self.num_heads * self.head_size + + self.past_sequence_length = 2 + self.max_sequence_length = 12 + + def verify_fusion(self, expected_model_path, original_model_path): + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + model_type = "gpt2" + options = FusionOptions(model_type) + optimized_model = optimize_model( + original_model_path, + model_type, + self.num_heads, + self.hidden_size, + optimization_options=options, + opt_level=0, + ) + optimized_model.topological_sort(is_deterministic=True) + + self.assertTrue(str(expected_model.model.graph), str(optimized_model.model.graph)) + + def create_initializers(self, fused_model: bool = False): + initializers = [ + float_tensor("cos_cache", [self.max_sequence_length, self.head_size // 2]), + float_tensor("sin_cache", [self.max_sequence_length, self.head_size // 2]), + float_tensor("q_weight", [self.hidden_size, self.hidden_size]), + float_tensor("k_weight", [self.hidden_size, self.hidden_size]), + float_tensor("v_weight", [self.hidden_size, self.hidden_size]), + float_tensor("o_weight", [self.hidden_size, self.hidden_size]), + helper.make_tensor( + "sqrt_head_size", TensorProto.FLOAT, [1], np.array([np.sqrt(self.head_size)], dtype=np.float32) + ), + helper.make_tensor("neg_int_max", TensorProto.FLOAT, [1], np.array([-sys.maxsize - 1], dtype=np.int64)), + helper.make_tensor("num_heads", TensorProto.FLOAT, [1], np.array([self.num_heads], dtype=np.float32)), + helper.make_tensor("head_size", TensorProto.FLOAT, [1], np.array([self.head_size], dtype=np.float32)), + helper.make_tensor("hidden_size", TensorProto.FLOAT, [1], np.array([self.hidden_size], dtype=np.float32)), + helper.make_tensor("zero", TensorProto.FLOAT, [1], np.array([0], dtype=np.int64)), + helper.make_tensor("one", TensorProto.FLOAT, [1], np.array([1], dtype=np.int64)), + helper.make_tensor("two", TensorProto.FLOAT, [1], np.array([2], dtype=np.int64)), + helper.make_tensor("three", TensorProto.FLOAT, [1], np.array([3], dtype=np.int64)), + ] + return initializers + + def create_inputs_and_outputs(self, model_type: str): + attn_mask_size = [self.batch_size, self.sequence_length] + if model_type == "llama2_msft": + attn_mask_size.append(self.sequence_length) + + inputs = [ + helper.make_tensor_value_info( + "input_0", TensorProto.FLOAT, [self.batch_size, self.sequence_length, self.hidden_size] + ), + helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]), + helper.make_tensor_value_info("attn_mask", TensorProto.INT64, attn_mask_size), + ] + if model_type in {"past", "merged", "llama2_msft", "70b_distributed_merged"}: + inputs.extend( + [ + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length, self.head_size], + ), + helper.make_tensor_value_info( + "past_value", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length, self.head_size], + ), + ] + ) + outputs = [ + helper.make_tensor_value_info( + "output_0", TensorProto.FLOAT, [self.batch_size, self.sequence_length, self.hidden_size] + ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length + 1, self.head_size], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length + 1, self.head_size], + ), + ] + return inputs, outputs + + def create_matmul_nodes(self, is_fused: bool, model_type: str): + q_matmul_node = helper.make_node( + "MatMul", + inputs=["input_0", "q_weight"], + outputs=["q_out" if is_fused or model_type == "llama2_msft" else "q_matmul_out"], + name="Q_MatMul", + ) + + k_matmul_node = helper.make_node( + "MatMul", + inputs=["input_0", "k_weight"], + outputs=["k_out" if is_fused or model_type == "llama2_msft" else "k_matmul_out"], + name="K_MatMul", + ) + + v_matmul_node = helper.make_node( + "MatMul", + inputs=["input_0", "v_weight"], + outputs=["v_out"], + name="V_MatMul", + ) + + return [q_matmul_node, k_matmul_node, v_matmul_node] + + def create_rotary_embeddings( + self, + is_fused: bool, + model_type: str, + interleaved: bool, + inputs: List[TensorProto], + initializers: List[TensorProto], + ): + def get_first_rope_input(node_type: str): + if is_fused or model_type == "llama2_msft": + # q_out/k_out + return f"{node_type}_out" + if model_type in {"no_past", "past", "merged", "70b_distributed_merged"}: + if node_type == "k": + return "k_before_rope" + return "q_before_rope" + return "" + + def get_first_rope_output(node_type: str): + if is_fused or model_type in {"llama2_msft", "past", "merged", "70b_distributed_merged"}: + if node_type == "q": + return "q_rope" + return "k_rope" + if model_type in {"no_past"}: + if node_type == "k": + return "present_key" + return "q_rope" + return "" + + q_rope_node = helper.make_node( + "RotaryEmbedding", + inputs=[get_first_rope_input("q"), inputs[1].name, initializers[0].name, initializers[1].name], + outputs=[get_first_rope_output("q")], + name="Q_RotaryEmbedding", + interleaved=int(interleaved), + ) + + k_rope_node = helper.make_node( + "RotaryEmbedding", + inputs=[get_first_rope_input("k"), inputs[1].name, initializers[0].name, initializers[1].name], + outputs=[get_first_rope_output("k")], + name="K_RotaryEmbedding", + interleaved=int(interleaved), + ) + + return [q_rope_node, k_rope_node] + + def create_q_path(self, model_type: str): + if model_type == "llama2_msft": + transpose_q_node = helper.make_node( + "Transpose", + inputs=["q_rope"], + outputs=["q_transposed"], + name="Transpose_q", + perm=[0, 2, 1, 3], + ) + reshape_q_node = helper.make_node( + "Reshape", + inputs=["q_transposed", "concat_q_extra_out"], + outputs=["q"], + name="Reshape_q", + ) + return [transpose_q_node, reshape_q_node] + + reshape_q_node = helper.make_node( + "Reshape", + inputs=["q_matmul_out", "concat_q_extra_out"], + outputs=["q_reshaped"], + name="Reshape_q", + ) + transpose_q_node = helper.make_node( + "Transpose", + inputs=["q_reshaped"], + outputs=["q_before_rope"], + name="Transpose_q", + ) + return [reshape_q_node, transpose_q_node] + + def create_k_path_llama2_msft(self): + # Create k cache slicing path + k_cache_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["position_ids", "zero"], + outputs=["k_pos_id"], + ) + k_cache_slice_node = helper.make_node( + "Slice", + inputs=["past_key", "zero", "k_pos_id", "two", "one"], + outputs=["k_cache_sliced"], + ) + # Create k path + transpose_k_1_node = helper.make_node( + "Transpose", + inputs=["k_rope"], + outputs=["k_rope_transposed"], + name="Transpose_k_1", + perm=[0, 2, 1, 3], + ) + concat_k_node = helper.make_node( + "Concat", + inputs=["k_cache_sliced", "k_rope_transposed"], + outputs=["present_key"], + name="Concat_k", + axis=2, + ) + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["present_key"], + outputs=["present_key_transposed"], + name="Transpose_k_2", + perm=[0, 2, 3, 1], + ) + reshape_k_node = helper.make_node( + "Reshape", + inputs=["present_key_transposed", "concat_k_extra_out"], + outputs=["k"], + name="Reshape_k", + ) + return [ + k_cache_unsqueeze_node, + k_cache_slice_node, + transpose_k_1_node, + concat_k_node, + transpose_k_2_node, + reshape_k_node, + ] + + def create_k_path_hf(self, model_type: str): + reshape_k_node = helper.make_node( + "Reshape", + inputs=["k_matmul_out", "concat_k_extra_out"], + outputs=["k_reshaped"], + name="Reshape_k", + ) + transpose_k_1_node = helper.make_node( + "Transpose", + inputs=["k_reshaped"], + outputs=["k_before_rope"], + name="Transpose_k_1", + perm=[0, 2, 1, 3], + ) + k_nodes = [reshape_k_node, transpose_k_1_node] + + if model_type == "70b_distributed_merged": + concat_k_node = helper.make_node( + "Concat", + inputs=["past_key", "k_rope"], + outputs=["present_key"], + axis=2, + ) + shape_k1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k1_out"], name="Shape_k1") + shape_k2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k2_out"], name="Shape_k2") + shape_k3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k3_out"], name="Shape_k3") + shape_k4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k4_out"], name="Shape_k4") + + gather_k_1 = helper.make_node( + "Gather", + inputs=["shape_k1_out", "one"], + outputs=["gather_k1_out"], + name="Gather_k_1", + axis=0, + ) + gather_k_2 = helper.make_node( + "Gather", + inputs=["shape_k2_out", "one"], + outputs=["gather_k2_out"], + name="Gather_k_2", + axis=0, + ) + gather_k_3 = helper.make_node( + "Gather", + inputs=["shape_k3_out", "one"], + outputs=["gather_k3_out"], + name="Gather_k_3", + axis=0, + ) + gather_k_4 = helper.make_node( + "Gather", + inputs=["shape_k4_out", "one"], + outputs=["gather_k4_out"], + name="Gather_k_4", + axis=0, + ) + + unsqueeze_k_1 = helper.make_node( + "Unsqueeze", + inputs=["present_value", "zero"], + outputs=["unsqueeze_k1_out"], + name="Unsqueeze_k1", + ) + unsqueeze_k_2 = helper.make_node( + "Unsqueeze", + inputs=["gather_k1_out", "zero"], + outputs=["unsqueeze_k2_out"], + name="Unsqueeze_k2", + ) + unsqueeze_k_3 = helper.make_node( + "Unsqueeze", + inputs=["gather_k2_out", "zero"], + outputs=["unsqueeze_k3_out"], + name="Unsqueeze_k3", + ) + unsqueeze_k_4 = helper.make_node( + "Unsqueeze", + inputs=["gather_k3_out", "zero"], + outputs=["unsqueeze_k4_out"], + name="Unsqueeze_k4", + ) + unsqueeze_k_5 = helper.make_node( + "Unsqueeze", + inputs=["gather_k4_out", "zero"], + outputs=["unsqueeze_k5_out"], + name="Unsqueeze_k5", + ) + + concat_k_2 = helper.make_node( + "Concat", + inputs=["unsqueeze_k2_out", "unsqueeze_k3_out", "One", "unsqueeze_k4_out", "unsqueeze_k5_out"], + outputs=["concat_k2_ouot"], + name="Concat_k2", + axis=0, + ) + reshape_k_2 = helper.make_node( + "Reshape", + inputs=["concat_k2_ouot", "One"], + outputs=["reshape_k2_out"], + name="Reshape_k_2", + ) + shape_k5 = helper.make_node("Shape", inputs=["reshape_k2_out"], outputs=["shape_k5_out"], name="Shape_k5") + constant_of_shape_k_1 = helper.make_node( + "ConstantOfShape", + inputs=["shape_k5_out"], + outputs=["constant_of_shape_k1_out"], + name="ConstantOfShape_k1", + ) + mul_k_1 = helper.make_node( + "Mul", + inputs=["constant_of_shape_k1_out", "One"], + outputs=["mul_k1_out"], + name="mul_k1", + ) + equal_k_1 = helper.make_node( + "Equal", + inputs=["reshape_k2_out", "mul_k1_out"], + outputs=["equal_k_1_out"], + name="equal_k1", + ) + where_k_1 = helper.make_node( + "Where", + inputs=["equal_k_1_out", "constant_of_shape_k1_out", "reshape_k2_out"], + outputs=["where_k_1_out"], + name="where_k1", + ) + unsqueeze_k_6 = helper.make_node( + "Unsqueeze", + inputs=["gather_k1_out", "zero"], + outputs=["unsqueeze_k6_out"], + name="Unsqueeze_k6", + ) + mul_k_2 = helper.make_node( + "Mul", + inputs=["gather_k2_out", "One"], + outputs=["mul_k2_out"], + name="mul_k2", + ) + unsqueeze_k_7 = helper.make_node( + "Unsqueeze", + inputs=["mul_k2_out", "zero"], + outputs=["unsqueeze_k7_out"], + name="Unsqueeze_k7", + ) + unsqueeze_k_8 = helper.make_node( + "Unsqueeze", + inputs=["gather_k3_out", "zero"], + outputs=["unsqueeze_k8_out"], + name="Unsqueeze_k8", + ) + unsqueeze_k_9 = helper.make_node( + "Unsqueeze", + inputs=["gather_k4_out", "zero"], + outputs=["unsqueeze_k9_out"], + name="Unsqueeze_k9", + ) + concat_k_3 = helper.make_node( + "Concat", + inputs=["unsqueeze_k6_out", "unsqueeze_k7_out", "unsqueeze_k8_out", "unsqueeze_k9_out"], + outputs=["concat_k3_out"], + name="Concat_k3", + axis=0, + ) + expand_k_1 = helper.make_node( + "Expand", + inputs=["unsqueeze_k1_out", "where_k_1_out"], + outputs=["expand_k1_out"], + name="expand_k1", + ) + reshape_k_3 = helper.make_node( + "Reshape", + inputs=["expand_k1_out", "concat_k3_out"], + outputs=["reshape_k3_out"], + name="Reshape_k_3", + ) + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["reshape_k3_out"], + outputs=["k"], + name="Transpose_k_2", + perm=[0, 1, 3, 2], + ) + + k_nodes_for_70b_model = [ + concat_k_node, + shape_k1, + shape_k2, + shape_k3, + shape_k4, + gather_k_1, + gather_k_2, + gather_k_3, + gather_k_4, + unsqueeze_k_1, + unsqueeze_k_2, + unsqueeze_k_3, + unsqueeze_k_4, + unsqueeze_k_5, + concat_k_2, + reshape_k_2, + shape_k5, + constant_of_shape_k_1, + mul_k_1, + equal_k_1, + where_k_1, + unsqueeze_k_6, + mul_k_2, + unsqueeze_k_7, + unsqueeze_k_8, + unsqueeze_k_9, + concat_k_3, + expand_k_1, + reshape_k_3, + transpose_k_2_node, + ] + k_nodes.extend(k_nodes_for_70b_model) + return k_nodes + else: + if model_type in {"past", "merged"}: + concat_k_node = helper.make_node( + "Concat", + inputs=["past_key", "k_rope"], + outputs=["present_key"], + axis=2, + ) + k_nodes.append(concat_k_node) + + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["present_key"], + outputs=["k"], + name="Transpose_k_2", + perm=[0, 1, 3, 2], + ) + return k_nodes + [transpose_k_2_node] # noqa: RUF005 + + def create_k_path(self, model_type: str): + if model_type == "llama2_msft": + return self.create_k_path_llama2_msft() + return self.create_k_path_hf(model_type) + + def create_attn_mask_path_llama2_msft(self): + x_shape_node = helper.make_node( + "Shape", + inputs=["input_0"], + outputs=["input_0_shape"], + name="Shape_input", + ) + x_get_seq_len_node = helper.make_node( + "Gather", + inputs=["input_0_shape", "one"], + outputs=["input_0_seq_len"], + name="Gather_input", + axis=0, + ) + x_new_seq_len_node = helper.make_node( + "Add", + inputs=["position_ids", "input_0_seq_len"], + outputs=["new_seq_len"], + name="Add_mask", + ) + unsqueeze_0_node = helper.make_node( + "Unsqueeze", + inputs=["position_ids", "zero"], + outputs=["unsqueeze_mask_0_out"], + name="Unsqueeze_mask_0", + ) + unsqueeze_1_node = helper.make_node( + "Unsqueeze", + inputs=["new_seq_len", "zero"], + outputs=["unsqueeze_mask_1_out"], + name="Unsqueeze_mask_1", + ) + unsqueeze_2_node = helper.make_node( + "Unsqueeze", + inputs=["new_seq_len", "zero"], + outputs=["unsqueeze_mask_2_out"], + name="Unsqueeze_mask_2", + ) + slice_mask_1_node = helper.make_node( + "Slice", + inputs=["attn_mask", "unsqueeze_mask_0_out", "unsqueeze_mask_1_out", "one", "one"], + outputs=["slice_mask_1_out"], + name="Slice_mask_1", + ) + slice_mask_2_node = helper.make_node( + "Slice", + inputs=["slice_mask_1_out", "zero", "unsqueeze_mask_2_out", "two", "one"], + outputs=["slice_mask_2_out"], + name="Slice_mask_2", + ) + concat_mask_node = helper.make_node( + "Concat", + inputs=["slice_mask_2_out" for _ in range(self.num_heads)], + outputs=["attn_mask_out"], + name="Concat_mask", + axis=0, + ) + return [ + x_shape_node, + x_get_seq_len_node, + x_new_seq_len_node, + unsqueeze_0_node, + unsqueeze_1_node, + unsqueeze_2_node, + slice_mask_1_node, + slice_mask_2_node, + concat_mask_node, + ] + + def create_attn_mask_path_hf(self, model_type: str): + unsqueeze_1_node = helper.make_node( + "Unsqueeze", + inputs=["attn_mask", "one"], + outputs=["unsqueeze_1_mask_out"], + name="Unsqueeze_1_mask", + ) + unsqueeze_2_node = helper.make_node( + "Unsqueeze", + inputs=["unsqueeze_1_mask_out", "two"], + outputs=["unsqueeze_2_mask_out"], + name="Unsqueeze_2_mask", + ) + expand_node = helper.make_node( + "Expand", + inputs=["unsqueeze_2_mask_out", "zero"], + outputs=["expand_out"], + name="Expand_mask", + ) + cast_node = helper.make_node( + "Cast", + inputs=["expand_out"], + outputs=["cast_out"], + name="Cast_mask", + to=TensorProto.FLOAT, + ) + sub_node = helper.make_node( + "Sub", + inputs=["one", "cast_out"], + outputs=["sub_out"], + name="Sub_mask", + ) + where_node = helper.make_node( + "Where", + inputs=["zero", "neg_int_max", "sub_out"], + outputs=["where_out" if model_type != "past" else "attn_mask_out"], + name="Where_mask", + ) + attn_mask_nodes = [unsqueeze_1_node, unsqueeze_2_node, expand_node, cast_node, sub_node, where_node] + + if model_type == "past": + return attn_mask_nodes + + add_node = helper.make_node( + "Add", + inputs=["where_out", "zero"], + outputs=["attn_mask_out"], + name="Add_mask", + ) + return attn_mask_nodes + [add_node] # noqa: RUF005 + + def create_attn_mask_path(self, is_fused: bool, model_type: str): + if model_type == "llama2_msft": + attn_mask_nodes = self.create_attn_mask_path_llama2_msft() + if is_fused: + attn_mask_nodes.pop() + attn_mask_nodes[-1].output[0] = "attn_mask_out" + return attn_mask_nodes + + attn_mask_nodes = self.create_attn_mask_path_hf(model_type) + if is_fused: + new_output_name = "attn_mask_out_mask" + attn_mask_nodes[-1].output[0] = new_output_name + concat_mask_node = helper.make_node( + "Concat", + inputs=[new_output_name for _ in range(self.num_heads)], + outputs=["attn_mask_out"], + name="Concat_mask", + axis=0, + ) + attn_mask_nodes.append(concat_mask_node) + return attn_mask_nodes + + def create_qk_path(self, model_type: str): + matmul_qk_node = helper.make_node( + "MatMul", + inputs=["q" if model_type == "llama2_msft" else "q_rope", "k"], + outputs=["qk"], + name="MatMul_q_k", + ) + div_node = helper.make_node( + "Div", + inputs=["qk", "sqrt_head_size"], + outputs=["qk_div"], + name="Div_0", + ) + add_node = helper.make_node( + "Add", + inputs=["qk_div", "attn_mask_out"], + outputs=["qk_plus_mask"], + name="Add_0", + ) + softmax_node = helper.make_node( + "Softmax", + inputs=["qk_plus_mask"], + outputs=["softmax_out"], + name="Softmax_0", + ) + return [matmul_qk_node, div_node, add_node, softmax_node] + + def create_v_path(self, model_type: str): + reshape_v_1_node = helper.make_node( + "Reshape", + inputs=["v_out", "concat_v_1_extra_out"], + outputs=["reshape_v_1_out"], + name="Reshape_v_1", + ) + transpose_v_1_node = helper.make_node( + "Transpose", + inputs=["reshape_v_1_out"], + outputs=["transpose_v_1_out" if model_type != "no_past" else "present_value"], + name="Transpose_v_1", + ) + v_nodes = [reshape_v_1_node, transpose_v_1_node] + + if model_type == "no_past": + return v_nodes + + if model_type in {"past", "merged", "70b_distributed_merged"}: + concat_v_node = helper.make_node( + "Concat", + inputs=["past_value", "transpose_v_1_out"], + outputs=["present_value"], + name="Concat_v", + axis=2, + ) + + if model_type != "70b_distributed_merged": + return v_nodes + [concat_v_node] # noqa: RUF005 + + shape_v1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_1_out"], name="Shape_v1") + shape_v2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_2_out"], name="Shape_v2") + shape_v3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_3_out"], name="Shape_v3") + shape_v4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_4_out"], name="Shape_v4") + gather_v_1 = helper.make_node( + "Gather", + inputs=["shape_1_out", "one"], + outputs=["gather_1_out"], + name="Gather_v1", + axis=0, + ) + gather_v_2 = helper.make_node( + "Gather", + inputs=["shape_2_out", "one"], + outputs=["gather_2_out"], + name="Gather_v2", + axis=0, + ) + gather_v_3 = helper.make_node( + "Gather", + inputs=["shape_3_out", "one"], + outputs=["gather_3_out"], + name="Gather_v3", + axis=0, + ) + gather_v_4 = helper.make_node( + "Gather", + inputs=["shape_4_out", "one"], + outputs=["gather_4_out"], + name="Gather_v4", + axis=0, + ) + unsqueeze_v_1 = helper.make_node( + "Unsqueeze", + inputs=["present_value", "zero"], + outputs=["unsqueeze_v1_out"], + name="Unsqueeze_v1", + ) + unsqueeze_v_2 = helper.make_node( + "Unsqueeze", + inputs=["gather_1_out", "zero"], + outputs=["unsqueeze_v2_out"], + name="Unsqueeze_v2", + ) + unsqueeze_v_3 = helper.make_node( + "Unsqueeze", + inputs=["gather_2_out", "zero"], + outputs=["unsqueeze_v3_out"], + name="Unsqueeze_v3", + ) + unsqueeze_v_4 = helper.make_node( + "Unsqueeze", + inputs=["gather_3_out", "zero"], + outputs=["unsqueeze_v4_out"], + name="Unsqueeze_v4", + ) + unsqueeze_v_5 = helper.make_node( + "Unsqueeze", + inputs=["gather_4_out", "zero"], + outputs=["unsqueeze_v5_out"], + name="Unsqueeze_v5", + ) + concat_v_2 = helper.make_node( + "Concat", + inputs=["unsqueeze_v2_out", "unsqueeze_v3_out", "One", "unsqueeze_v4_out", "unsqueeze_v5_out"], + outputs=["concat_v2_ouot"], + name="Concat_v2", + axis=0, + ) + reshape_v_2 = helper.make_node( + "Reshape", + inputs=["concat_v2_ouot", "One"], + outputs=["reshape_v2_out"], + name="Reshape_v2", + ) + shape_v5 = helper.make_node("Shape", inputs=["reshape_v2_out"], outputs=["shape_5_out"], name="Shape_v5") + constant_of_shape_v_1 = helper.make_node( + "ConstantOfShape", + inputs=["shape_5_out"], + outputs=["constant_of_shape_v1_out"], + name="ConstantOfShape_v1", + ) + mul_v_1 = helper.make_node( + "Mul", + inputs=["constant_of_shape_v1_out", "One"], + outputs=["mul_v1_out"], + name="mul_v1", + ) + equal_v_1 = helper.make_node( + "Equal", + inputs=["reshape_v2_out", "mul_v1_out"], + outputs=["equal_v_1_out"], + name="equal_v1", + ) + where_v_1 = helper.make_node( + "Where", + inputs=["equal_v_1_out", "constant_of_shape_v1_out", "reshape_v2_out"], + outputs=["where_v_1_out"], + name="where_v1", + ) + unsqueeze_v_6 = helper.make_node( + "Unsqueeze", + inputs=["gather_1_out", "zero"], + outputs=["unsqueeze_v6_out"], + name="Unsqueeze_v6", + ) + mul_v_2 = helper.make_node( + "Mul", + inputs=["gather_2_out", "One"], + outputs=["mul_v2_out"], + name="mul_v2", + ) + unsqueeze_v_7 = helper.make_node( + "Unsqueeze", + inputs=["mul_v2_out", "zero"], + outputs=["unsqueeze_v7_out"], + name="Unsqueeze_v7", + ) + unsqueeze_v_8 = helper.make_node( + "Unsqueeze", + inputs=["gather_3_out", "zero"], + outputs=["unsqueeze_v8_out"], + name="Unsqueeze_v8", + ) + unsqueeze_v_9 = helper.make_node( + "Unsqueeze", + inputs=["gather_4_out", "zero"], + outputs=["unsqueeze_v9_out"], + name="Unsqueeze_v9", + ) + concat_v_3 = helper.make_node( + "Concat", + inputs=["unsqueeze_v6_out", "unsqueeze_v7_out", "unsqueeze_v8_out", "unsqueeze_v9_out"], + outputs=["concat_v3_out"], + name="Concat_v3", + axis=0, + ) + expand_v_1 = helper.make_node( + "Expand", + inputs=["unsqueeze_v1_out", "where_v_1_out"], + outputs=["expand_v1_out"], + name="expand_v1", + ) + reshape_v_3 = helper.make_node( + "Reshape", + inputs=["expand_v1_out", "concat_v3_out"], + outputs=["reshape_v3_out"], + name="Reshape_v3", + ) + + v_nodes_for_70b_model = [ + concat_v_node, + shape_v1, + shape_v2, + shape_v3, + shape_v4, + gather_v_1, + gather_v_2, + gather_v_3, + gather_v_4, + unsqueeze_v_1, + unsqueeze_v_2, + unsqueeze_v_3, + unsqueeze_v_4, + unsqueeze_v_5, + concat_v_2, + reshape_v_2, + shape_v5, + constant_of_shape_v_1, + mul_v_1, + equal_v_1, + where_v_1, + unsqueeze_v_6, + mul_v_2, + unsqueeze_v_7, + unsqueeze_v_8, + unsqueeze_v_9, + concat_v_3, + expand_v_1, + reshape_v_3, + ] + v_nodes.extend(v_nodes_for_70b_model) + + return v_nodes + + # Create extra nodes for `position_ids` + unsqueeze_v_node = helper.make_node( + "Unsqueeze", + inputs=["position_ids", "zero"], + outputs=["unsqueeze_v_out"], + name="Unsqueeze_v", + ) + slice_v_node = helper.make_node( + "Slice", + inputs=["past_value", "zero", "unsqueeze_v_out", "two", "one"], + outputs=["v_cache_sliced_out"], + name="Slice_v", + ) + concat_v_node = helper.make_node( + "Concat", + inputs=["v_cache_sliced_out", "transpose_v_1_out"], + outputs=["present_value"], + name="Concat_v", + axis=2, + ) + v_nodes.extend([unsqueeze_v_node, slice_v_node, concat_v_node]) + + # Create remaining nodes for v path + transpose_v_2_node = helper.make_node( + "Transpose", + inputs=["present_value"], + outputs=["transpose_v_2_out"], + name="Transpose_v_2", + ) + reshape_v_2_node = helper.make_node( + "Reshape", + inputs=["transpose_v_2_out", "concat_v_2_extra_out"], + outputs=["v"], + name="Reshape_v_2", + ) + return v_nodes + [transpose_v_2_node, reshape_v_2_node] # noqa: RUF005 + + def create_qkv_path(self, model_type: str): + matmul_qkv_node = helper.make_node( + "MatMul", + inputs=["softmax_out", "v" if model_type == "llama2_msft" else "present_value"], + outputs=["softmax_v_out"], + name="MatMul_softmax_v", + ) + qkv_nodes = [matmul_qkv_node] + + if model_type == "llama2_msft": + reshape_qkv_1_node = helper.make_node( + "Reshape", + inputs=["softmax_v_out", "concat_qkv_1_extra_out"], + outputs=["reshape_qkv_1_out"], + name="Reshape_qkv_1", + ) + qkv_nodes.append(reshape_qkv_1_node) + + transpose_qkv_node = helper.make_node( + "Transpose", + inputs=["reshape_qkv_1_out" if model_type == "llama2_msft" else "softmax_v_out"], + outputs=["transpose_qkv_out"], + name="Transpose_qkv", + ) + reshape_qkv_2_node = helper.make_node( + "Reshape", + inputs=["transpose_qkv_out", "concat_qkv_2_extra_out"], + outputs=["attn_output"], + name="Reshape_qkv_2", + ) + + return qkv_nodes + [transpose_qkv_node, reshape_qkv_2_node] # noqa: RUF005 + + def create_concat_unsqueeze_paths(self, model_type: str, reshape_nodes: List[NodeProto]): + # Create initial shape paths + shape_0_node = helper.make_node( + "Shape", + inputs=["input_0"], + outputs=["input_0_shape_0"], + name="Shape_0", + ) + gather_0_node = helper.make_node( + "Gather", + inputs=["input_0_shape_0", "zero"], + outputs=["input_0_shape_0_indexed"], + name="Gather_0", + axis=0, + ) + shape_1_node = helper.make_node( + "Shape", + inputs=["input_0"], + outputs=["input_0_shape_1"], + name="Shape_1", + ) + gather_1_node = helper.make_node( + "Gather", + inputs=["input_0_shape_1", "one"], + outputs=["input_0_shape_1_indexed"], + name="Gather_1", + axis=0, + ) + extra_nodes = [shape_0_node, gather_0_node, shape_1_node, gather_1_node] + + if model_type == "llama2_msft": + mul_node = helper.make_node( + "Mul", + inputs=[gather_0_node.output[0], "num_heads"], + outputs=["mul_extra_out"], + name="Mul_extra_0", + ) + add_node = helper.make_node( + "Add", + inputs=[gather_1_node.output[0], "position_ids"], + outputs=["add_extra_out"], + name="Add_extra_0", + ) + extra_nodes.extend([mul_node, add_node]) + + for i, reshape_node in enumerate(reshape_nodes): + use_mul_and_add_nodes_0 = model_type == "llama2_msft" and reshape_node.output[0] in {"q", "k", "v"} + use_mul_and_add_nodes_1 = model_type == "llama2_msft" and reshape_node.output[0] in {"k", "v"} + + unsqueeze_0_node = helper.make_node( + "Unsqueeze", + inputs=[gather_0_node.output[0] if not use_mul_and_add_nodes_0 else "mul_extra_out", "zero"], + outputs=[f"unsqueeze_extra_{2*i}"], + name=f"Unsqueeze_extra_{2*i}", + ) + unsqueeze_1_node = helper.make_node( + "Unsqueeze", + inputs=[gather_1_node.output[0] if not use_mul_and_add_nodes_1 else "add_extra_out", "zero"], + outputs=[f"unsqueeze_extra_{2*i + 1}"], + name=f"Unsqueeze_extra_{2*i + 1}", + ) + + reshape_name = reshape_node.name + if reshape_name == "Reshape_qkv_2": + concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "hidden_size"] + elif reshape_name == "Reshape_qkv_1": + concat_node_inputs = [unsqueeze_0_node.output[0], "num_heads", unsqueeze_1_node.output[0], "head_size"] + elif reshape_name == "Reshape_v_2": + concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "head_size"] + elif reshape_name == "Reshape_v_1": + concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "num_heads", "head_size"] + elif reshape_name == "Reshape_k": + concat_node_inputs = [unsqueeze_0_node.output[0], "head_size", unsqueeze_1_node.output[0]] + elif reshape_name == "Reshape_q": + concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "head_size"] + + concat_node = helper.make_node( + "Concat", + inputs=concat_node_inputs, + outputs=[reshape_nodes[i].input[1]], + name=f"Concat_extra_{i}", + axis=0, + ) + extra_nodes.extend([unsqueeze_0_node, unsqueeze_1_node, concat_node]) + + return extra_nodes + + def create_end_nodes(self, model_type): + if model_type == "70b_distributed_merged": + matmul_o_node = helper.make_node( + "MatMul", + inputs=["attn_output", "o_weight"], + outputs=["output_proj"], + name="MatMul_o_proj", + ) + all_reduce = helper.make_node( + "AllReduce", + inputs=["output_proj"], + outputs=["allreduce_proj"], + name="allreduce_proj", + ) + end_node = helper.make_node( + "Add", + inputs=["zero", "allreduce_proj"], + outputs=["output_0"], + name="Add_normalize_node", + ) + return [matmul_o_node, all_reduce, end_node] + + matmul_o_node = helper.make_node( + "MatMul", + inputs=["attn_output", "o_weight"], + outputs=["output_proj"], + name="MatMul_o_proj", + ) + end_node = helper.make_node( + "Add", + inputs=["zero", "output_proj"], + outputs=["output_0"], + name="Add_normalize_node", + ) + return [matmul_o_node, end_node] + + def create_fused_model(self, model_type: str, interleaved: bool, initializers: List[TensorProto]): + inputs, outputs = self.create_inputs_and_outputs(model_type) + matmul_nodes = self.create_matmul_nodes(True, model_type=model_type) + rope_nodes = self.create_rotary_embeddings(True, model_type, interleaved, inputs, initializers) + attn_mask_nodes = self.create_attn_mask_path(True, model_type) + + mha_inputs = [ + rope_nodes[0].output[0], # q + rope_nodes[1].output[0], # k + matmul_nodes[-1].output[0], # v + "", # bias + "attn_mask_out" if model_type == "llama2_msft" else "", # attn_mask + "attn_mask_out" if model_type != "llama2_msft" else "", # add_qk + "past_key" if model_type != "no_past" else "", # past_key + "past_value" if model_type != "no_past" else "", # past_value + ] + mha_node = helper.make_node( + "MultiHeadAttention", + inputs=mha_inputs, + outputs=["attn_output", "present_key", "present_value"], + name="MultiHeadAttention_0", + num_heads=self.num_heads, + ) + + end_nodes = self.create_end_nodes(model_type) + + graph = helper.make_graph( + nodes=matmul_nodes + rope_nodes + attn_mask_nodes + [mha_node] + end_nodes, + name="RotaryAttention_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def create_test_model(self, model_type: str, interleaved: bool, initializers: List[TensorProto]): + inputs, outputs = self.create_inputs_and_outputs(model_type) + matmul_nodes = self.create_matmul_nodes(False, model_type) + rope_nodes = self.create_rotary_embeddings(False, model_type, interleaved, inputs, initializers) + + # Create main paths + q_nodes = self.create_q_path(model_type) + k_nodes = self.create_k_path(model_type) + attn_mask_nodes = self.create_attn_mask_path(False, model_type) + qk_nodes = self.create_qk_path(model_type) + v_nodes = self.create_v_path(model_type) + qkv_nodes = self.create_qkv_path(model_type) + + reshape_nodes = list(filter(lambda node: node.op_type == "Reshape", q_nodes + k_nodes + v_nodes + qkv_nodes)) + extra_nodes = self.create_concat_unsqueeze_paths(model_type, reshape_nodes) + + end_nodes = self.create_end_nodes(model_type) + + first_set_of_nodes = matmul_nodes + rope_nodes + q_nodes + k_nodes + attn_mask_nodes + second_set_of_nodes = qk_nodes + v_nodes + qkv_nodes + extra_nodes + end_nodes + graph = helper.make_graph( + nodes=first_set_of_nodes + second_set_of_nodes, + name="RotaryAttention_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="ai.onnx", version=17) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def check_models(self, model_type: str, interleaved: bool): + initializers = self.create_initializers() + + expected_model_filename = "expected_model.onnx" + expected_model = self.create_fused_model(model_type, interleaved, initializers) + onnx.save(expected_model, expected_model_filename) + + original_model_filename = "original_model.onnx" + original_model = self.create_test_model(model_type, interleaved, initializers) + onnx.save(original_model, original_model_filename) + + self.verify_fusion(expected_model_filename, original_model_filename) + os.remove(expected_model_filename) + os.remove(original_model_filename) + + def test_llama2_msft_model(self): + model_type = "llama2_msft" + interleaved = True + self.check_models(model_type, interleaved) + + def test_hf_decoder_model(self): + model_type = "no_past" + interleaved = False + self.check_models(model_type, interleaved) + + def test_hf_decoder_with_past_model(self): + model_type = "past" + interleaved = False + self.check_models(model_type, interleaved) + + def test_hf_decoder_merged_model(self): + model_type = "merged" + interleaved = False + self.check_models(model_type, interleaved) + + def test_hf_70b_distributed_decoder_merged_model(self): + model_type = "70b_distributed_merged" + interleaved = False + self.check_models(model_type, interleaved) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py b/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py new file mode 100644 index 0000000000..e86bdda7ba --- /dev/null +++ b/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py @@ -0,0 +1,243 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import unittest +from typing import List + +import numpy as np +import onnx +from onnx import TensorProto, helper +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +def float_tensor(name: str, shape: List[int], random=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights) + + +class TestSimplifiedLayerNormFusion(unittest.TestCase): + def setUp(self): + self.vocab_size = 5 + self.batch_size = 2 + self.sequence_length = 8 + self.hidden_size = 16 + self.epsilon = 0.000009999999747378752 + + def verify_fusion(self, expected_model_path, original_model_path): + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + options = FusionOptions("gpt2") + optimized_model = optimize_model(original_model_path, optimization_options=options) + optimized_model.topological_sort(is_deterministic=True) + + self.assertTrue(str(expected_model.model.graph), str(optimized_model.model.graph)) + + def create_initializers(self, use_embed_weight: bool = False): + initializers = [ + helper.make_tensor("Two", TensorProto.FLOAT, [1], np.array([2], dtype=np.float32)), + helper.make_tensor("epsilon", TensorProto.FLOAT, [1], np.array([self.epsilon], dtype=np.float32)), + helper.make_tensor("One", TensorProto.FLOAT, [1], np.array([1], dtype=np.float32)), + float_tensor("scale", [self.hidden_size]), + ] + if use_embed_weight: + initializers = [ # noqa: RUF005 + float_tensor("embed_weight", [self.vocab_size, self.hidden_size]) + ] + initializers + return initializers + + def create_inputs_and_outputs(self, start_node_type: str): + inputs, start_node = None, None + if start_node_type == "Add": + start_node = helper.make_node( + "Add", + inputs=["input_0", "input_1"], + outputs=["D"], + name="Add_0", + ) + input_0 = helper.make_tensor_value_info( + "input_0", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.hidden_size], + ) + input_1 = helper.make_tensor_value_info( + "input_1", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.hidden_size], + ) + inputs = [input_0, input_1] + elif start_node_type == "Gather": + start_node = helper.make_node( + "Gather", + inputs=["embed_weight", "input_0"], + outputs=["D"], + name="Gather_0", + ) + input_0 = helper.make_tensor_value_info( + "input_0", + TensorProto.INT64, + [self.batch_size, self.sequence_length], + ) + inputs = [input_0] + else: + # start_node_type is a graph input + assert start_node_type == "GraphInput" + input_0 = helper.make_tensor_value_info( + "D", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.hidden_size], + ) + inputs = [input_0] + + outputs = [ + helper.make_tensor_value_info( + "output_0", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.hidden_size], + ) + ] + return inputs, outputs, start_node + + def create_fused_model(self, start_node_type: str, initializers: List[TensorProto]): + inputs, outputs, start_node = self.create_inputs_and_outputs(start_node_type) + + sln_node = helper.make_node( + "SimplifiedLayerNormalization", + inputs=[start_node.output[0] if start_node is not None else "D", initializers[0].name], + outputs=[outputs[0].name], + axis=-1, + epsilon=initializers[2].float_data[0], + stash_type=1, + ) + + graph = helper.make_graph( + nodes=[sln_node] + ([] if start_node is None else [start_node]), + name="SimplifiedLayerNorm_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + # Notation follows https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary + def create_test_model(self, start_node_type: str, first_parent_idx: int, initializers: List[TensorProto]): + end_node = helper.make_node( + "Mul", + inputs=["scale", "Normalized"] if first_parent_idx == 1 else ["Normalized", "scale"], + outputs=["output_0"], + name="Mul_1", + ) + mul_node = helper.make_node( + "Mul", + inputs=["D", "InvStdDev"], + outputs=["Normalized"], + name="Mul_0", + ) + div_node = helper.make_node( + "Div", + inputs=["One", "StdDev"], + outputs=["InvStdDev"], + name="Div_0", + ) + sqrt_node = helper.make_node( + "Sqrt", + inputs=["VarEps"], + outputs=["StdDev"], + name="Sqrt_0", + ) + add_node = helper.make_node( + "Add", + inputs=["Var", "epsilon"], + outputs=["VarEps"], + name="Add_1", + ) + reducemean_node = helper.make_node( + "ReduceMean", + inputs=["DD"], + outputs=["Var"], + name="ReduceMean_0", + ) + pow_node = helper.make_node( + "Pow", + inputs=["D", "Two"], + outputs=["DD"], + name="Pow_0", + ) + + inputs, outputs, start_node = self.create_inputs_and_outputs(start_node_type) + + main_nodes = [pow_node, reducemean_node, add_node, sqrt_node, div_node, mul_node, end_node] + graph = helper.make_graph( + nodes=main_nodes + ([] if start_node is None else [start_node]), + name="SimplifiedLayerNorm_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def check_models(self, start_node_type: str, first_parent_idx: int, initializers: List[TensorProto]): + expected_model_filename = "expected_model.onnx" + expected_model = self.create_fused_model(start_node_type, initializers) + onnx.save(expected_model, expected_model_filename) + + original_model_filename = "original_model.onnx" + original_model = self.create_test_model(start_node_type, first_parent_idx, initializers) + onnx.save(original_model, original_model_filename) + + self.verify_fusion(expected_model_filename, original_model_filename) + os.remove(expected_model_filename) + os.remove(original_model_filename) + + # sim_ln_nodes_1 + def test_simplified_layernorm_add_idx1(self): + start_node_type = "Add" + first_parent_idx = 1 + initializers = self.create_initializers() + self.check_models(start_node_type, first_parent_idx, initializers) + + # sim_ln_nodes_2 + def test_simplified_layernorm_gather_idx1(self): + start_node_type = "Gather" + first_parent_idx = 1 + initializers = self.create_initializers(use_embed_weight=True) + self.check_models(start_node_type, first_parent_idx, initializers) + + # sim_ln_nodes_3 + def test_simplified_layernorm_add_idx0(self): + start_node_type = "Add" + first_parent_idx = 0 + initializers = self.create_initializers() + self.check_models(start_node_type, first_parent_idx, initializers) + + # sim_ln_nodes_4 + def test_simplified_layernorm_gather_graph_input(self): + start_node_type = "GraphInput" + first_parent_idx = 0 + initializers = self.create_initializers() + self.check_models(start_node_type, first_parent_idx, initializers) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py b/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py new file mode 100644 index 0000000000..5b3a3f18cd --- /dev/null +++ b/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py @@ -0,0 +1,276 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import unittest +from typing import Dict, List + +import numpy as np +import onnx +from onnx import TensorProto, helper, numpy_helper +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.optimizer import optimize_model + + +def float_tensor(name: str, shape: List[int], random=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights) + + +class TestFusion(unittest.TestCase): + def verify_skip_layer_norm_fusion( + self, + model_path: str, + expected_counter: Dict[str, int], + expected_inputs: List[str], + expected_outputs: List[str], + ): + options = FusionOptions("bert") + optimized_model = optimize_model(model_path, optimization_options=options, opt_level=0) + + ops = ["Add", "LayerNormalization", "SkipLayerNormalization", "Cast"] + for op in ops: + nodes = optimized_model.get_nodes_by_op_type(op) + print(op, len(nodes), expected_counter[op]) + self.assertEqual(len(nodes), expected_counter[op]) + + if op == "SkipLayerNormalization" and expected_counter[op] == 1: + print(nodes[0].input) + print(nodes[0].output) + self.assertEqual(nodes[0].input, expected_inputs) + self.assertEqual(nodes[0].output, expected_outputs) + + def create_test_model( + self, + batch_size: int = 1, + sequence_length: int = 2, + hidden_size: int = 3, + add_graph_output: bool = True, + bias: int = 0, # 0 - no bias, 1 - bias in input_1, 2 - bias in input_2 + cast_before_add_bias=False, + ): + matmul = helper.make_node("MatMul", ["input_0", "matmul_weight"], ["matmul_output"], "matmul") + cast_node = helper.make_node("Cast", ["matmul_output"], ["matmul_output_cast"], to=1) + add_bias = helper.make_node( + "Add", + ["matmul_output_cast" if cast_before_add_bias else "matmul_output", "bias"], + ["input_1" if bias == 1 else "input_2"], + "add_bias", + ) + + add_before_layer_norm = helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm") + layer_norm = helper.make_node( + "LayerNormalization", + ["layernorm_input", "layer_norm_weight", "layer_norm_bias"], + ["output"], + "layernorm", + axis=-1, + epsion=0.000009999999747378752, + ) + + initializers = [ # initializers + float_tensor("layer_norm_weight", [hidden_size]), + float_tensor("layer_norm_bias", [hidden_size]), + ] + + if bias > 0: + weight_tensor = float_tensor("matmul_weight", [hidden_size, hidden_size]) + # MatMul weights is float16 when there is Cast node + if cast_before_add_bias: + weight_tensor.CopyFrom( + numpy_helper.from_array(numpy_helper.to_array(weight_tensor).astype(np.float16), weight_tensor.name) + ) + initializers.append(weight_tensor) + + bias_tensor = float_tensor("bias", [hidden_size]) + initializers.append(bias_tensor) + + input_0 = helper.make_tensor_value_info( + "input_0", + TensorProto.FLOAT16 if cast_before_add_bias else TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ) + + input_1 = helper.make_tensor_value_info( + "input_1", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ) + + input_2 = helper.make_tensor_value_info( + "input_2", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ) + + output = helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ) + + layernorm_input = helper.make_tensor_value_info( + "layernorm_input", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ) + + nodes = [add_before_layer_norm, layer_norm] + if bias > 0: + nodes.insert(0, add_bias) + if cast_before_add_bias: + nodes.insert(0, cast_node) + nodes.insert(0, matmul) + + node_name = "SkipLayerNormFusionModel" + if bias == 0: + graph = helper.make_graph( + nodes, + node_name, + [input_1, input_2], # inputs + [output, layernorm_input] if add_graph_output else [output], # outputs + initializers, + ) + elif bias == 1: + graph = helper.make_graph( + nodes, + node_name, + [input_0, input_2], # inputs + [output, layernorm_input] if add_graph_output else [output], # outputs + initializers, + ) + else: + graph = helper.make_graph( + nodes, + node_name, + [input_0, input_1], # inputs + [output, layernorm_input] if add_graph_output else [output], # outputs + initializers, + ) + + onnx_opset = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) + return helper.make_model(graph, opset_imports=(onnx_opset,)) + + def test_skip_layer_norm_no_graph_output(self): + model = self.create_test_model(batch_size=1, sequence_length=2, hidden_size=3, add_graph_output=False) + model_name = "skip_layer_norm_add_no_graph_output.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + { + "Add": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 1, + "Cast": 0, + }, + ["input_1", "input_2", "layer_norm_weight", "layer_norm_bias"], + ["output"], + ) + os.remove(model_name) + + def test_skip_layer_norm_graph_output(self): + model = self.create_test_model(batch_size=1, sequence_length=2, hidden_size=3, add_graph_output=True) + model_name = "skip_layer_norm_add_has_graph_output.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + { + "Add": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 1, + "Cast": 0, + }, + ["input_1", "input_2", "layer_norm_weight", "layer_norm_bias"], + ["output", "", "", "layernorm_input"], + ) + os.remove(model_name) + + def test_skip_layer_norm_graph_output_bias1(self): + model = self.create_test_model(batch_size=1, sequence_length=2, hidden_size=3, add_graph_output=True, bias=1) + model_name = "skip_layer_norm_add_has_graph_output_bias1.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + { + "Add": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 1, + "Cast": 0, + }, + ["matmul_output", "input_2", "layer_norm_weight", "layer_norm_bias", "bias"], + ["output", "", "", "layernorm_input"], + ) + os.remove(model_name) + + def test_skip_layer_norm_graph_output_bias2(self): + model = self.create_test_model(batch_size=1, sequence_length=2, hidden_size=3, add_graph_output=True, bias=2) + model_name = "skip_layer_norm_add_has_graph_output_bias1.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + { + "Add": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 1, + "Cast": 0, + }, + ["matmul_output", "input_1", "layer_norm_weight", "layer_norm_bias", "bias"], + ["output", "", "", "layernorm_input"], + ) + os.remove(model_name) + + def test_skip_layer_norm_graph_output_cast_bias1(self): + model = self.create_test_model( + batch_size=1, sequence_length=2, hidden_size=3, add_graph_output=True, bias=1, cast_before_add_bias=True + ) + model_name = "skip_layer_norm_add_has_graph_output_cast_bias1.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + { + "Add": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 1, + "Cast": 1, + }, + ["matmul_output_cast", "input_2", "layer_norm_weight", "layer_norm_bias", "bias"], + ["output", "", "", "layernorm_input"], + ) + os.remove(model_name) + + def test_skip_layer_norm_graph_output_cast_bias2(self): + model = self.create_test_model( + batch_size=1, sequence_length=2, hidden_size=3, add_graph_output=True, bias=2, cast_before_add_bias=True + ) + model_name = "skip_layer_norm_add_has_graph_output_cast_bias2.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + { + "Add": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 1, + "Cast": 1, + }, + ["matmul_output_cast", "input_1", "layer_norm_weight", "layer_norm_bias", "bias"], + ["output", "", "", "layernorm_input"], + ) + os.remove(model_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_whisper.py b/onnxruntime/test/python/transformers/test_whisper.py index a2aa6383c2..ceda5a88c3 100644 --- a/onnxruntime/test/python/transformers/test_whisper.py +++ b/onnxruntime/test/python/transformers/test_whisper.py @@ -37,9 +37,20 @@ def verify_fusion(self, optimized_model, expected_model_filename): expected_model = OnnxModel(onnx.load(expected_model_path)) expected_model.topological_sort(is_deterministic=True) - self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph)) + nodes = optimized_model.model.graph.node + self.assertEqual(len(nodes), len(expected_model.model.graph.node)) - # Attention type #1 in onnx_model_bart.py + for i in range(len(nodes)): + self.assertEqual(nodes[i], expected_model.model.graph.node[i]) + + for expected_initializer in expected_model.model.graph.initializer: + self.assertTrue( + OnnxModel.has_same_value( + optimized_model.get_initializer(expected_initializer.name), expected_initializer + ) + ) + + # Attention type #1 in fusion_bart_attention.py def test_encoder_attention_fusion_with_skiplayernorm(self): num_heads = 4 hidden_size = 64 @@ -56,7 +67,7 @@ def test_encoder_attention_fusion_with_skiplayernorm(self): os.remove(model_path) self.verify_fusion(optimized_model, "encoder_attention_with_sln_fused.onnx") - # Attention type #2 in onnx_model_bart.py + # Attention type #2 in fusion_bart_attention.py def test_decoder_attention_fusion_with_skiplayernorm(self): num_heads = 4 hidden_size = 64 @@ -73,7 +84,7 @@ def test_decoder_attention_fusion_with_skiplayernorm(self): os.remove(model_path) self.verify_fusion(optimized_model, "decoder_attention_with_sln_fused.onnx") - # Attention type #4 in onnx_model_bart.py + # Attention type #4 in fusion_bart_attention.py def test_decoder_multihead_attention_fusion(self): num_heads = 4 hidden_size = 64 @@ -89,7 +100,7 @@ def test_decoder_multihead_attention_fusion(self): os.remove(model_path) self.verify_fusion(optimized_model, "decoder_mha_fused.onnx") - # Attention type #3 in onnx_model_bart.py + # Attention type #3 in fusion_bart_attention.py def test_decoder_with_past_multihead_self_attention_fusion_with_skiplayernorm(self): num_heads = 4 hidden_size = 64 @@ -107,7 +118,7 @@ def test_decoder_with_past_multihead_self_attention_fusion_with_skiplayernorm(se os.remove(model_path) self.verify_fusion(optimized_model, "decoder_with_past_self_mha_fused.onnx") - # Attention type #5 in onnx_model_bart.py + # Attention type #5 in fusion_bart_attention.py def test_decoder_with_past_multihead_cross_attention_fusion(self): num_heads = 4 hidden_size = 64 @@ -123,7 +134,7 @@ def test_decoder_with_past_multihead_cross_attention_fusion(self): os.remove(model_path) self.verify_fusion(optimized_model, "decoder_with_past_cross_mha_fused.onnx") - # Attention type #4 in onnx_model_bart.py + # Attention type #4 in fusion_bart_attention.py def test_decoder_multihead_attention_split_bias_fusion(self): num_heads = 4 hidden_size = 64 @@ -140,7 +151,7 @@ def test_decoder_multihead_attention_split_bias_fusion(self): os.remove(model_path) self.verify_fusion(optimized_model, "decoder_mha_split_bias_fused.onnx") - # Attention type #3 in onnx_model_bart.py + # Attention type #3 in fusion_bart_attention.py def test_decoder_with_past_multihead_self_attention_split_bias_fusion_with_skiplayernorm(self): num_heads = 4 hidden_size = 64 @@ -160,7 +171,7 @@ def test_decoder_with_past_multihead_self_attention_split_bias_fusion_with_skipl os.remove(model_path) self.verify_fusion(optimized_model, "decoder_with_past_self_mha_split_bias_fused.onnx") - # Attention type #5 in onnx_model_bart.py + # Attention type #5 in fusion_bart_attention.py def test_decoder_with_past_multihead_cross_attention_split_bias_fusion(self): num_heads = 4 hidden_size = 64 diff --git a/onnxruntime/test/testdata/attention_no_mask_fp16.onnx b/onnxruntime/test/testdata/attention_no_mask_fp16.onnx new file mode 100644 index 0000000000..fe8aa0038d Binary files /dev/null and b/onnxruntime/test/testdata/attention_no_mask_fp16.onnx differ diff --git a/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32.onnx b/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32.onnx new file mode 100644 index 0000000000..1dec991008 Binary files /dev/null and b/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32.onnx differ diff --git a/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32_input.npy b/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32_input.npy new file mode 100644 index 0000000000..706f508836 Binary files /dev/null and b/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32_input.npy differ diff --git a/onnxruntime/test/testdata/ort_github_issue_17000.onnx b/onnxruntime/test/testdata/ort_github_issue_17000.onnx new file mode 100644 index 0000000000..8320c19cb6 Binary files /dev/null and b/onnxruntime/test/testdata/ort_github_issue_17000.onnx differ diff --git a/onnxruntime/test/testdata/ort_github_issue_17000.ort b/onnxruntime/test/testdata/ort_github_issue_17000.ort new file mode 100644 index 0000000000..08d9826dd5 Binary files /dev/null and b/onnxruntime/test/testdata/ort_github_issue_17000.ort differ diff --git a/onnxruntime/test/testdata/ort_github_issue_17000.py b/onnxruntime/test/testdata/ort_github_issue_17000.py new file mode 100644 index 0000000000..43c10f5590 --- /dev/null +++ b/onnxruntime/test/testdata/ort_github_issue_17000.py @@ -0,0 +1,77 @@ +import numpy as np +import onnx +from onnx import TensorProto, helper, numpy_helper + + +def order_repeated_field(repeated_proto, key_name, order): + order = list(order) + repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) + + +def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs): + node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs) + if doc_string == "": + node.doc_string = "" + order_repeated_field(node.attribute, "name", kwargs.keys()) + return node + + +def make_graph(*args, doc_string=None, **kwargs): + graph = helper.make_graph(*args, doc_string=doc_string, **kwargs) + if doc_string == "": + graph.doc_string = "" + return graph + + +test_graph = make_graph( + name="test_graph", + # model input of a sequence type to test IsSparseTensor issue + inputs=[ + helper.make_tensor_sequence_value_info("seq_in", TensorProto.FLOAT, shape=None), + ], + outputs=[ + helper.make_tensor_value_info("still_has_elements", TensorProto.BOOL, shape=[]), + ], + initializer=[ + numpy_helper.from_array(np.array(0, dtype="int64"), name="i0"), + ], + nodes=[ + make_node("SequenceLength", inputs=["seq_in"], outputs=["seq_len"], name="get_seq_len"), + make_node("Greater", inputs=["seq_len", "i0"], outputs=["has_elements"], name="get_has_elements"), + # If node with one branch that has no nodes to test the allocation planner issue + # if sequence has elements: + # remove one + # output bool of whether it still has elements + # else: + # output false (gives us branch with no nodes) + make_node( + "If", + name="test_if", + inputs=["has_elements"], + outputs=["still_has_elements"], + then_branch=make_graph( + name="then", + inputs=[], + outputs=[helper.make_tensor_value_info("then_bool_out", TensorProto.BOOL, shape=[])], + nodes=[ + make_node("SequenceErase", inputs=["seq_in", "i0"], outputs=["seq_less_one"]), + make_node("SequenceLength", inputs=["seq_less_one"], outputs=["new_seq_len"]), + make_node("Greater", inputs=["new_seq_len", "i0"], outputs=["then_bool_out"]), + ], + ), + else_branch=make_graph( + name="else", + initializer=[numpy_helper.from_array(np.array(False, dtype="bool"), name="else_bool_out")], + inputs=[], + outputs=[helper.make_tensor_value_info("else_bool_out", TensorProto.BOOL, shape=[])], + nodes=[], + ), + ), + ], +) + +# Graph with Sequence operations and an If node that has a subgraph with no nodes +model = helper.make_model(opset_imports=[helper.make_operatorsetid("ai.onnx", 14)], ir_version=7, graph=test_graph) + +onnx.shape_inference.infer_shapes(model, strict_mode=True) +onnx.save(model, "ort_github_issue_17000.onnx") diff --git a/onnxruntime/test/testdata/required_ops.config b/onnxruntime/test/testdata/required_ops.config index ac9d46666e..e70362bab4 100644 --- a/onnxruntime/test/testdata/required_ops.config +++ b/onnxruntime/test/testdata/required_ops.config @@ -3,9 +3,9 @@ ai.onnx;7;Abs,Add,And,BatchNormalization,Concat,Conv,Dropout,Flatten,Foo,Gather, ai.onnx;8;Add,Conv,Flatten,Gemm,MatMul,MaxPool,Mul,Relu,Reshape ai.onnx;9;Abs,Add,BatchNormalization,Cast,Clip,Concat,Constant,ConstantOfShape,Conv,Div,Equal,Gather,Gemm,Identity,If,LayerNormalization,LeakyRelu,Loop,MatMul,Mul,Pow,ReduceMean,Relu,Reshape,Scan,Shape,Sigmoid,Slice,Softmax,Softsign,Sqrt,Squeeze,Sub,Tanh,Transpose,Unsqueeze ai.onnx;10;Add,Cast,Concat,ConstantOfShape,Div,Dropout,Erf,Expand,Gather,Greater,Identity,If,LayerNormalization,Loop,MatMul,Mul,Neg,NonZero,Pow,ReduceMean,ReduceSum,Shape,Sqrt,Squeeze,Sub,Tanh,Transpose,Unsqueeze -ai.onnx;11;Abs,Add,ArgMax,BatchNormalization,Cast,Clip,Concat,Constant,ConstantOfShape,Conv,Div,Equal,Exp,Expand,Flatten,Gather,Gemm,Identity,If,LayerNormalization,Log,Loop,MatMul,MatMulInteger,Max,Min,Mul,Neg,Pow,RandomUniform,Range,ReduceMean,ReduceSum,ReduceSumSquare,Relu,Reshape,Scan,SequenceConstruct,SequenceInsert,SequenceLength,Shape,Sigmoid,Slice,Softmax,Split,Sqrt,Squeeze,Sub,Sum,Tanh,Transpose,Unsqueeze,Where +ai.onnx;11;Abs,Add,ArgMax,BatchNormalization,Cast,Clip,Concat,Constant,ConstantOfShape,Conv,Div,Equal,Exp,Expand,Flatten,Gather,Gemm,Identity,If,LayerNormalization,Log,Loop,MatMul,MatMulInteger,Max,Min,Mul,Neg,Pow,RandomUniform,Range,ReduceMean,ReduceSum,ReduceSumSquare,Relu,Reshape,Scan,SequenceConstruct,SequenceErase,SequenceInsert,SequenceLength,Shape,Sigmoid,Slice,Softmax,Split,Sqrt,Squeeze,Sub,Sum,Tanh,Transpose,Unsqueeze,Where ai.onnx;12;Add,And,Cast,Concat,Constant,ConstantOfShape,Conv,CumSum,Div,Dropout,DynamicQuantizeLinear,Equal,Erf,Expand,Flatten,Gather,GatherND,Gemm,GlobalAveragePool,Greater,Identity,If,IsInf,LayerNormalization,Less,Loop,MatMul,MatMulInteger,Min,Mul,Not,Pad,Pow,RandomNormalLike,RandomUniform,ReduceMean,ReduceSum,Relu,Reshape,Shape,Slice,Softmax,SoftmaxCrossEntropyLoss,SparseSoftmaxCrossEntropy,Split,Sqrt,Squeeze,Sub,Tanh,Transpose,Unsqueeze,Where -ai.onnx;13;Abs,Add,Cast,Concat,ConstantOfShape,Conv,DequantizeLinear,DynamicQuantizeLinear,Equal,Expand,FooBar,FooBar_Attr,Gather,Identity,LayerNormalization,MatMul,MatMulInteger,Mul,Pad,Pow,QuantizeLinear,Range,ReduceSum,Reshape,Shape,Tanh,Transpose,Unsqueeze,Where +ai.onnx;13;Abs,Add,Cast,Concat,ConstantOfShape,Conv,DequantizeLinear,DynamicQuantizeLinear,Equal,Expand,FooBar,FooBar_Attr,Gather,Greater,Identity,If,LayerNormalization,MatMul,MatMulInteger,Mul,Pad,Pow,QuantizeLinear,Range,ReduceSum,Reshape,Shape,Tanh,Transpose,Unsqueeze,Where ai.onnx;14;Add,ArgMax,Cast,Conv,Identity,Relu,Sigmoid,Sub ai.onnx;314159;Add ai.onnx.contrib;1;StringLower diff --git a/onnxruntime/test/testdata/required_ops_and_types.config b/onnxruntime/test/testdata/required_ops_and_types.config index 17687906d7..41f3742147 100644 --- a/onnxruntime/test/testdata/required_ops_and_types.config +++ b/onnxruntime/test/testdata/required_ops_and_types.config @@ -1,9 +1,12 @@ # required ops and types for ORT format models in testdata -ai.onnx;1;Conv{"inputs": {"0": ["float"]}},Foo,Identity +ai.onnx;1;Conv{"inputs": {"0": ["float"]}} ai.onnx;5;Reshape ai.onnx;6;Relu{"inputs": {"0": ["float"]}} ai.onnx;7;Add{"inputs": {"0": ["float"]}},Gemm{"inputs": {"0": ["float"]}},Mul{"inputs": {"0": ["float"]}} ai.onnx;8;MaxPool{"inputs": {"0": ["float"]}},Sum{"inputs": {"0": ["float"]}} ai.onnx;9;Cast{"inputs": {"0": ["float"]}, "outputs": {"0": ["bool"]}} -ai.onnx;11;ArgMax{"inputs": {"0": ["float"]}},If,Loop +ai.onnx;10;QLinearConv{"inputs": {"0": ["uint8_t"]}} +ai.onnx;11;ArgMax{"inputs": {"0": ["float"]}},Clip{"inputs": {"0": ["float"]}},Conv{"inputs": {"0": ["float"]}},If,Loop,SequenceErase,SequenceLength +ai.onnx;13;DequantizeLinear{"inputs": {"0": ["int32_t", "uint8_t"]}},Greater{"inputs": {"0": ["int64_t"]}},If,QuantizeLinear{"outputs": {"0": ["uint8_t"]}} ai.onnx.ml;1;ArrayFeatureExtractor,LinearClassifier,Normalizer,ZipMap +test;1;Foo diff --git a/onnxruntime/test/testdata/training_api/ort_format/checkpoint b/onnxruntime/test/testdata/training_api/ort_format/checkpoint new file mode 100644 index 0000000000..ab35c9ad5a Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/checkpoint differ diff --git a/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort b/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort new file mode 100644 index 0000000000..69b2c7e029 Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort differ diff --git a/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort b/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort new file mode 100644 index 0000000000..88f1924623 Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort differ diff --git a/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py b/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py new file mode 100644 index 0000000000..70e8c4ac01 --- /dev/null +++ b/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""This file is used to generate test data for ort format model tests in + orttraining/orttraining/test/training_api/core/training_capi_tests.cc.""" + +import onnx +import torch +import torch.nn as nn + +from onnxruntime.training import artifacts + + +class SimpleNet(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + + +def model_export(pt_model, model_path, input_size): + # Generate random input data + input_data = torch.randn(32, input_size) + torch.onnx.export( + pt_model, + input_data, + model_path, + input_names=["input"], + output_names=["output"], + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + ) + + +def main(): + # Set the dimensions for input, hidden, and output layers + input_size = 10 + hidden_size = 20 + output_size = 5 + + # Create an instance of the neural network + pt_model = SimpleNet(input_size, hidden_size, output_size) + + train_model_path = "simplenet_training.onnx" + model_export(pt_model, train_model_path, input_size) + + onnx_model = onnx.load(train_model_path) + + requires_grad = ["fc2.weight", "fc2.bias"] + frozen_params = [param.name for param in onnx_model.graph.initializer if param.name not in requires_grad] + + # Generate the training artifacts. + artifacts.generate_artifacts( + onnx_model, + requires_grad=requires_grad, + frozen_params=frozen_params, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + ort_format=True, + ) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/testdata/training_api/ort_format/training_model.ort b/onnxruntime/test/testdata/training_api/ort_format/training_model.ort new file mode 100644 index 0000000000..94bda328a9 Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/training_model.ort differ diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 28af61e15b..e224507bc7 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -268,7 +268,7 @@ std::unique_ptr DefaultCannExecutionProvider() { std::unique_ptr DefaultDmlExecutionProvider() { #ifdef USE_DML - if (auto factory = DMLProviderFactoryCreator::Create(0)) + if (auto factory = DMLProviderFactoryCreator::Create(0, false, false, false)) return factory->CreateProvider(); #endif return nullptr; diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index f8e0545574..a7e8135c7e 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -2070,5 +2070,30 @@ IMPLEMENT_GRADIENT_BUILDER(GetLeakyReluGradient) { {GO(0), O(0)}, {GI(0)}, SrcNodeAttributes())}; } +IMPLEMENT_GRADIENT_BUILDER(GetConvTransposeGradient) { + std::vector outputs; + for (int i = 0; i < GetSrcNodeInputSize(); i++) { + if (IsGradientRequiredForSrcNodeInput(i)) { + outputs.push_back(GI(i)); + } else { + outputs.push_back(ArgDef("", nullptr)); + } + } + + return std::vector{ + NodeDef(OpDef{"ConvTransposeGrad", kMSDomain, 1}, + {GO(0), I(0), I(1)}, + outputs, + SrcNodeAttributes())}; +} + +IMPLEMENT_GRADIENT_BUILDER(GetResizeGradient) { + return std::vector{ + NodeDef(OpDef{"ResizeGrad", kMSDomain, 1}, + {GO(0), I(0), I(1), I(2)}, + {GI(0)}, + SrcNodeAttributes())}; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index ca86777d36..df1819cd54 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -88,6 +88,8 @@ DECLARE_GRADIENT_BUILDER(GetLSTMGradient) DECLARE_GRADIENT_BUILDER(GetGRUGradient) DECLARE_GRADIENT_BUILDER(GetReciprocalGradient) DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient) +DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient) +DECLARE_GRADIENT_BUILDER(GetResizeGradient) DECLARE_GRADIENT_BUILDER(GetExternalGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index cc9a762ff8..e6d6f792ac 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -120,6 +120,8 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("GRUTraining", GetGRUGradient); REGISTER_GRADIENT_BUILDER("Reciprocal", GetReciprocalGradient); REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient); + REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient); + REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient); REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient); }; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 60867accb8..840cc9db46 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -4908,6 +4908,41 @@ Return true if all elements are true and false otherwise. } } }); + + ONNX_CONTRIB_OPERATOR_SCHEMA(ConvTransposeGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "dY", "Gradient of output Y", "T") + .Input(1, "X", "Input tensor", "T") + .Input(2, "W", "Weight tensor", "T") + .Output(0, "dX", "Gradient of X", "T", OpSchema::Optional) + .Output(1, "dW", "Gradient of W", "T", OpSchema::Optional) + .Output(2, "dB", "Gradient of B", "T", OpSchema::Optional) + .AllowUncheckedAttributes() + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(ResizeGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "dY", "Gradient of output Y.", "T") + .Input(1, "X", "Input tensor to the Resize operator.", "T") + .Input(2, "roi", "The roi input to the Resize operator.", "T", OpSchema::Optional) + .Input(3, "scales", "The scales input to the Resize operator.", "tensor(float)", OpSchema::Optional) + .Output(0, "dX", "Gradient of the input X.", "T") + .AllowUncheckedAttributes() + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 1, 0); + if (hasInputShape(ctx, 1)) { + propagateShapeFromInputToOutput(ctx, 1, 0); + } + }); } } // namespace training diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index eac17f3d4d..35d9755ba0 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -174,10 +174,11 @@ struct PyOptimizer { PyOptimizer(const std::string optimizer_model_uri, onnxruntime::training::api::CheckpointState* state, std::vector> providers, PySessionOptions* session_options) : optimizer_() { + auto model_identifiers = onnxruntime::training::api::ModelIdentifiers("", std::nullopt, optimizer_model_uri); auto env = GetTrainingEnv().GetORTEnv(); // XXX: We hope that env will be around when optimizer needs it. optimizer_ = std::make_shared( - optimizer_model_uri, state, session_options->value, *env, providers, session_options->custom_op_domains_); + model_identifiers, state, session_options->value, *env, providers, session_options->custom_op_domains_); } std::shared_ptr optimizer_; @@ -941,9 +942,10 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn OrtDevice device, PySessionOptions* session_options) { std::vector> provider = GetExecutionProvidersForTrainingApis(device); auto env = GetTrainingEnv().GetORTEnv(); - return std::make_unique( - model_uri, state, session_options->value, *env, provider, eval_model_uri, - session_options->custom_op_domains_); + auto model_identifiers = onnxruntime::training::api::ModelIdentifiers(model_uri, eval_model_uri, std::nullopt); + return std::make_unique(model_identifiers, + state, session_options->value, *env, provider, + session_options->custom_op_domains_); })) .def("train_step", [](onnxruntime::training::api::Module* model, @@ -1063,17 +1065,60 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc"); checkpoint_state .def(py::init()) - .def("add_property", [](onnxruntime::training::api::CheckpointState* state, - const std::string& property_name, - const std::variant& property_value) { - state->property_bag.AddProperty(property_name, property_value); - }) - .def("get_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { - return state->property_bag.GetProperty(property_name); - }) - .def("has_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { - return state->property_bag.HasProperty(property_name); - }); + .def("add_property", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& property_name, + const std::variant& property_value) { + state->property_bag.AddProperty(property_name, property_value); + }) + .def("get_property", + [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { + return state->property_bag.GetProperty(property_name); + }) + .def("has_property", + [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { + return state->property_bag.HasProperty(property_name); + }) + .def("copy_parameter_from", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& parameter_name, OrtValue& value) -> void { + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); + } + ORT_THROW_IF_ERROR(it->second->CopyFrom( + state->module_checkpoint_state.train_session_data_transfer_mgr, value)); + }) + .def("get_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); + } + return it->second; + }) + .def("has_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + return state->module_checkpoint_state.named_parameters.count(parameter_name); + }) + .def("parameter_names", + [](onnxruntime::training::api::CheckpointState* state) { + std::vector names; + for ([[maybe_unused]] auto& [name, value] : state->module_checkpoint_state.named_parameters) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; + }) + .def("property_names", + [](onnxruntime::training::api::CheckpointState* state) { + std::vector names; + for ([[maybe_unused]] auto& [name, value] : state->property_bag) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; + }); py::class_ training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc"); @@ -1109,6 +1154,21 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ORT_THROW_IF_ERROR(scheduler->Step()); }); + py::class_> + parameter(m, "Parameter"); + parameter + .def_property_readonly("name", &onnxruntime::training::api::Parameter::Name) + .def_property_readonly("data", &onnxruntime::training::api::Parameter::Data) + .def_property_readonly("grad", &onnxruntime::training::api::Parameter::Gradient) + .def_property_readonly("requires_grad", &onnxruntime::training::api::Parameter::RequiresGrad) + .def("copy_from", + [](onnxruntime::training::api::Parameter* parameter, + onnxruntime::training::api::CheckpointState* state, + OrtValue& value) -> void { + ORT_THROW_IF_ERROR(parameter->CopyFrom(state->module_checkpoint_state.train_session_data_transfer_mgr, value)); + }); + m.def( "save_checkpoint", [](const std::vector& trainable_tensor_protos_pybytes, diff --git a/orttraining/orttraining/python/training/api/checkpoint_state.py b/orttraining/orttraining/python/training/api/checkpoint_state.py index 285264bbed..ba95cd04fc 100644 --- a/orttraining/orttraining/python/training/api/checkpoint_state.py +++ b/orttraining/orttraining/python/training/api/checkpoint_state.py @@ -5,70 +5,171 @@ import os +import numpy as np + from onnxruntime.capi import _pybind_state as C +from onnxruntime.capi.onnxruntime_inference_collection import OrtValue -class CheckpointState: - """Class that holds the state of the training session +class Parameter: + """Class that represents a model parameter - This class holds all the state information of the training session such as the model parameters, - its gradients, the optimizer state and user defined properties. + This class represents a model parameter and provides access to its data, + gradient and other properties. This class is not expected to be instantiated directly. + Instead, it is returned by the `CheckpointState` object. + + Args: + parameter: The C.Parameter object that holds the underlying parameter data. + state: The C.CheckpointState object that holds the underlying session state. + """ + + def __init__(self, parameter: C.Parameter, state: C.CheckpointState): + self._parameter = parameter + self._state = state - User defined properties can be indexed by name from the `CheckpointState` object. + @property + def name(self) -> str: + """The name of the parameter""" + return self._parameter.name - To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method. + @property + def data(self) -> np.ndarray: + """The data of the parameter""" + return self._parameter.data.numpy() + + @data.setter + def data(self, value: np.ndarray) -> None: + """Sets the data of the parameter""" + self._parameter.copy_from(self._state, OrtValue.ortvalue_from_numpy(value)._ortvalue) + + @property + def grad(self) -> np.ndarray: + """The gradient of the parameter""" + return self._parameter.grad.numpy() if self._parameter.grad.has_value() else None + + @property + def requires_grad(self) -> bool: + """Whether or not the parameter requires its gradient to be computed""" + return self._parameter.requires_grad + + def __repr__(self) -> str: + """Returns a string representation of the parameter""" + return f"Parameter(name={self.name}, requires_grad={self.requires_grad})" + + +class Parameters: + """Class that holds all the model parameters + + This class holds all the model parameters and provides access to them. + This class is not expected to be instantiated directly. Instead, it is returned by the + `CheckpointState`'s parameters attribute. + This class behaves like a dictionary and provides access to the parameters by name. Args: - state: The C.Checkpoint state object that holds the underlying session state. + state: The C.CheckpointState object that holds the underlying session state. """ def __init__(self, state: C.CheckpointState): - if not isinstance(state, C.CheckpointState): - raise TypeError(f"Invalid argument for CheckpointState received {type(state)}") self._state = state - @classmethod - def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: - """Loads the checkpoint state from the checkpoint file + def __getitem__(self, name: str) -> Parameter: + """Gets the parameter associated with the given name + + Searches for the name in the parameters of the checkpoint state. Args: - checkpoint_uri: The path to the checkpoint file. + name: The name of the parameter Returns: - CheckpointState: The checkpoint state object. + The value of the parameter + + Raises: + KeyError: If the parameter is not found """ - return cls(C.load_checkpoint(os.fspath(checkpoint_uri))) - @classmethod - def save_checkpoint( - cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False - ) -> None: - """Saves the checkpoint state to the checkpoint file + if name not in self: + raise KeyError(f"Parameter {name} not found.") + + return Parameter(self._state.get_parameter(name), self._state) + + def __setitem__(self, name: str, value: np.ndarray) -> None: + """Sets the parameter value for the given name + + Searches for the name in the parameters of the checkpoint state. + If the name is found in parameters, the value is updated. Args: - state: The checkpoint state object. - checkpoint_uri: The path to the checkpoint file. - include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file. + name: The name of the parameter + value: The value of the parameter as a numpy array + + Raises: + KeyError: If the parameter is not found """ - C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state) + if name not in self: + raise KeyError(f"Parameter {name} not found.") + + self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue) + + def __contains__(self, name: str) -> bool: + """Checks if the parameter exists in the state + + Args: + name: The name of the parameter + + Returns: + True if the name is a parameter False otherwise + """ + + return self._state.has_parameter(name) + + def __iter__(self): + """Returns an iterator over the properties""" + for parameter_name in self._state.parameter_names(): + yield parameter_name, Parameter(self._state.get_parameter(parameter_name), self._state) + + def __repr__(self) -> str: + """Returns a string representation of the parameters""" + return self._state.parameter_names() + + def __len__(self) -> int: + """Returns the number of parameters""" + return len(self._state.parameter_names()) + + +class Properties: + def __init__(self, state: C.CheckpointState): + self._state = state def __getitem__(self, name: str) -> int | float | str: """Gets the property associated with the given name + Searches for the name in the properties of the checkpoint state. + Args: name: The name of the property Returns: The value of the property + + Raises: + KeyError: If the property is not found """ + + if name not in self: + raise KeyError(f"Property {name} not found.") + return self._state.get_property(name) def __setitem__(self, name: str, value: int | float | str) -> None: """Sets the property value for the given name + Searches for the name in the properties of the checkpoint state. + The value is added or updated in the properties. + Args: name: The name of the property value: The value of the property + Properties only support int, float and str values. """ self._state.add_property(name, value) @@ -79,6 +180,75 @@ def __contains__(self, name: str) -> bool: name: The name of the property Returns: - True if the property exists, False otherwise + True if the name is a property, False otherwise """ + return self._state.has_property(name) + + def __iter__(self): + """Returns an iterator over the properties""" + for property_name in self._state.property_names(): + yield property_name, self._state.get_property(property_name) + + def __repr__(self) -> str: + """Returns a string representation of the properties""" + return self._state.property_names() + + def __len__(self) -> int: + """Returns the number of properties""" + return len(self._state.property_names()) + + +class CheckpointState: + """Class that holds the state of the training session + + This class holds all the state information of the training session such as the model parameters, + its gradients, the optimizer state and user defined properties. + + To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method. + + Args: + state: The C.Checkpoint state object that holds the underlying session state. + """ + + def __init__(self, state: C.CheckpointState): + if not isinstance(state, C.CheckpointState): + raise TypeError(f"Invalid argument for CheckpointState received {type(state)}") + self._state = state + self._parameters = Parameters(self._state) + self._properties = Properties(self._state) + + @classmethod + def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: + """Loads the checkpoint state from the checkpoint file + + Args: + checkpoint_uri: The path to the checkpoint file. + + Returns: + CheckpointState: The checkpoint state object. + """ + return cls(C.load_checkpoint(os.fspath(checkpoint_uri))) + + @classmethod + def save_checkpoint( + cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False + ) -> None: + """Saves the checkpoint state to the checkpoint file + + Args: + state: The checkpoint state object. + checkpoint_uri: The path to the checkpoint file. + include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file. + """ + C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state) + + @property + def parameters(self) -> Parameters: + """Returns the model parameters from the checkpoint state""" + return self._parameters + + @property + def properties(self) -> Properties: + """Returns the properties from the checkpoint state""" + return self._properties diff --git a/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py b/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py index 7b24bb400b..1213342004 100644 --- a/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import copy +import logging import os from typing import List, Optional, Set, Tuple, Union @@ -70,13 +71,16 @@ def _move_initializers_to_inputs(model: onnx.ModelProto, initializer_names: Opti def _gradient_model_for( model: onnx.ModelProto, requires_grad: Set[str], - output_names: List[str], loss_name: str, options: Optional[SessionOptions] = None, ) -> onnx.ModelProto: """Builds the gradient graph on top of the given input forward only graph.""" - builder = GradientGraphBuilder(model.SerializeToString(), set(output_names), requires_grad, loss_name, options) + logging.debug( + "The loss output is %s. The gradient graph will be built starting from %s_grad.", loss_name, loss_name + ) + + builder = GradientGraphBuilder(model.SerializeToString(), {loss_name}, requires_grad, loss_name, options) builder.build() return onnx.load_from_string(builder.get_model()) @@ -123,7 +127,7 @@ def build_gradient_graph( optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options)) # Assumption is that the first graph output is the loss output - gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names, output_names[0], options) + gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names[0], options) _reorder_outputs(gradient_model, output_names, requires_grad) diff --git a/orttraining/orttraining/python/training/onnxblock/onnxblock.py b/orttraining/orttraining/python/training/onnxblock/onnxblock.py index 9f90a5a0c3..a2922353ac 100644 --- a/orttraining/orttraining/python/training/onnxblock/onnxblock.py +++ b/orttraining/orttraining/python/training/onnxblock/onnxblock.py @@ -205,6 +205,8 @@ def __call__(self, *args, **kwargs): model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY ) + logging.debug("Adding gradient accumulation nodes for training block %s", self.__class__.__name__) + _training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad) accessor._GLOBAL_ACCESSOR.model.CopyFrom(self._training_model) diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index 150f41eaec..59cf05bb08 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -18,7 +18,7 @@ from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed if not is_ortmodule_available(): - raise RuntimeError("ORTModule is not supported on this platform.") + raise ImportError("ORTModule is not supported on this platform.") def _defined_from_envvar(name, default_value, warn=True): diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 156c3e001d..7731724272 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -271,8 +271,3 @@ def upsample_nearest2d_gradient(): @register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec") def upsample_nearest3d_gradient(): return _upsample_gradient("upsample_nearest3d_backward", 3) - - -@register_gradient("org.pytorch.aten", "ATen", "upsample_bilinear2d", "vec") -def upsample_bilinear2d_gradient(): - return _upsample_gradient("upsample_bilinear2d_backward", 2) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index ac87dc6abf..938bc568b6 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -829,16 +829,3 @@ def upsample_nearest2d(g, input, output_size, scale_factors): @register_symbolic("upsample_nearest3d") def upsample_nearest3d(g, input, output_size, scale_factors): return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d") - - -@register_symbolic("upsample_bilinear2d") -def upsample_bilinear2d(g, input, output_size, align_corners, scale_factors): - return g.op( - "org.pytorch.aten::ATen", - input, - output_size, - align_corners, - scale_factors, - operator_s="upsample_bilinear2d", - overload_name_s="vec", - ) diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 39cc6bdd11..d236c0eca5 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -3039,6 +3039,239 @@ TEST(GradientCheckerTest, LeakyReluGrad) { UnaryOpGradientTest("LeakyRelu", kOnnxDomain, 16, nullptr, &transformer); } +#ifdef USE_CUDA +void ConvTransposeGradientCheckerTest(std::vector>* execution_providers) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"ConvTranspose"}; + + float error_tolerance = 3e-1f; + + // 1D convolution + { + TensorShape x_shape({2, 2, 5}); + TensorShape w_shape({2, 2, 3}); + TensorShape b_shape({2}); + TensorShape y_shape({2, 2, 5}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3}), MakeAttribute("pads", std::vector{1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 1D strided convolution + { + TensorShape x_shape({2, 1, 7}); + TensorShape w_shape({1, 1, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 13}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3}), MakeAttribute("pads", std::vector{1, 1}), + MakeAttribute("strides", std::vector{2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 1D pointwise convolution (with padding) + { + TensorShape x_shape({2, 1, 5}); + TensorShape w_shape({1, 1, 1}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 3}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{1}), MakeAttribute("pads", std::vector{1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 1D pointwise convolution (no padding) + { + TensorShape x_shape({2, 1, 5}); + TensorShape w_shape({1, 1, 1}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 5}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{1}), MakeAttribute("pads", std::vector{0, 0})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D convolution + { + TensorShape x_shape({1, 1, 3, 3}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({1, 1, 3, 3}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D convolution + { + TensorShape x_shape({2, 1, 5, 5}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 5, 5}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D pointwise convolution (with padding) + { + TensorShape x_shape({1, 1, 3, 3}); + TensorShape w_shape({1, 1, 1, 1}); + TensorShape b_shape({1}); + TensorShape y_shape({1, 1, 1, 1}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{1, 1}), + MakeAttribute("pads", std::vector{1, 1, 1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D pointwise convolution (no padding) + { + TensorShape x_shape({1, 1, 3, 3}); + TensorShape w_shape({1, 1, 1, 1}); + TensorShape b_shape({1}); + TensorShape y_shape({1, 1, 3, 3}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{1, 1}), + MakeAttribute("pads", std::vector{0, 0, 0, 0})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D strided convolution + { + TensorShape x_shape({2, 1, 7, 5}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 13, 9}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1}), MakeAttribute("strides", std::vector{2, 2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D dilated convolution (no padding) + { + TensorShape x_shape({2, 1, 5, 5}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 9, 9}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{0, 0, 0, 0}), + MakeAttribute("dilations", std::vector{2, 2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D dilated convolution (with padding) + { + TensorShape x_shape({2, 1, 7, 5}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 9, 7}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1}), + MakeAttribute("dilations", std::vector{2, 2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 3D convolution + { + TensorShape x_shape({2, 1, 5, 5, 5}); + TensorShape w_shape({1, 1, 3, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 5, 5, 5}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1, 1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 3D strided convolution + { + TensorShape x_shape({2, 1, 7, 5, 5}); + TensorShape w_shape({1, 1, 3, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 13, 9, 9}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1, 1, 1}), + MakeAttribute("strides", std::vector{2, 2, 2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } +} + +TEST(GradientCheckerTest, ConvTransposeGrad) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + ConvTransposeGradientCheckerTest(&execution_providers); +} + +// TODO: Enable test for ROCM +TEST(GradientCheckerTest, ResizeGrad) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + const std::vector attributes = { + MakeAttribute("coordinate_transformation_mode", "half_pixel"), + MakeAttribute("cubic_coeff_a", -0.75f), + MakeAttribute("exclude_outside", static_cast(0)), + MakeAttribute("extrapolation_value", 0.0f), + MakeAttribute("mode", "linear"), + MakeAttribute("nearest_mode", "floor")}; + + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"Resize", kOnnxDomain, 18}; + + TensorInfo x_info({1, 2, 4, 4}, true); + TensorInfo roi_info({4}, false, nullptr, DataTypeImpl::GetTensorType()); + TensorInfo scales_info({4}, false, nullptr, DataTypeImpl::GetTensorType()); + + TensorInfo y_info({1, 2, 8, 8}, true); + + std::vector> x_datas = {{0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f, + 0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f, + 0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f, + 0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 2.0f, 2.0f}}; + + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, roi_info, scales_info}, + {y_info}, &max_error, x_datas, attributes, true, false, &execution_providers)); + EXPECT_IS_TINY(max_error); +} + +#endif // USE_CUDA + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index b62e959556..0e88ce8e6c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1773,13 +1773,17 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) -def test_aten_upsample_bilinear(): +@pytest.mark.parametrize("interpolate_size_scale", ({"size": (8, 12)}, {"scale_factor": 4.7})) +@pytest.mark.parametrize("align_corners", (True, False)) +def test_resize_grad_correctness_bilinear_2d(interpolate_size_scale, align_corners): class _NeuralNetUpsampleBilinear(torch.nn.Module): def __init__(self): super().__init__() def forward(self, input): - return torch.nn.functional.interpolate(input, size=(8, 12), mode="bilinear") + return torch.nn.functional.interpolate( + input, align_corners=align_corners, mode="bilinear", **interpolate_size_scale + ) device = "cuda" pt_model = _NeuralNetUpsampleBilinear().to(device) @@ -4002,6 +4006,7 @@ def forward(self, bool_argument, input1): ], ) def test_unused_parameters(model, none_pt_params): + torch.manual_seed(2333) device = "cuda" N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10 # noqa: F841, N806 @@ -6096,7 +6101,6 @@ def run_step(model, input, positions): found_missing_inference_log = False for record in caplog.records: msg = record.getMessage() - print(msg) if "The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing" in msg: found_missing_inference_log = True break @@ -6205,3 +6209,167 @@ def run_step(model, x): _test_helpers.assert_values_are_close(pt_prediction, ort_prediction) _test_helpers.assert_values_are_close(pt_loss, ort_loss) _test_helpers.assert_values_are_close(pt_x.grad, ort_x.grad) + + +@pytest.mark.skipif( + os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm" +) +@pytest.mark.parametrize("use_fp16", [False, True]) +@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"]) +def test_conv_transpose_gradient(use_fp16, conv_algo_search): + class ChainedTransposedConv(nn.Module): + def __init__(self): + super().__init__() + + # Transposed Convolution 1D + self.conv1d_transpose = nn.ConvTranspose1d( + in_channels=4, out_channels=2, kernel_size=3, stride=2, padding=1 + ) + self.relu1 = nn.ReLU() + + # Transposed Convolution 2D + self.conv2d_transpose = nn.ConvTranspose2d( + in_channels=2, out_channels=3, kernel_size=3, stride=2, padding=1 + ) + self.relu2 = nn.ReLU() + + # Transposed Convolution 3D + self.conv3d_transpose = nn.ConvTranspose3d( + in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=1 + ) + self.relu3 = nn.ReLU() + + def forward(self, x): + out1d = self.relu1(self.conv1d_transpose(x)) + out2d = self.relu2(self.conv2d_transpose(out1d.unsqueeze(2))) + out3d = self.relu3(self.conv3d_transpose(out2d.unsqueeze(2))) + return out3d.squeeze(2) + + if conv_algo_search is not None: + os.environ["ORTMODULE_CONV_ALGO_SEARCH"] = conv_algo_search + + def run_step(model, x): + with amp.autocast(use_fp16): + loss = model(x).sum() + loss.backward() + + return ( + x.grad, + model.conv1d_transpose.weight.grad, + model.conv1d_transpose.bias.grad, + model.conv2d_transpose.weight.grad, + model.conv2d_transpose.bias.grad, + model.conv3d_transpose.weight.grad, + model.conv3d_transpose.bias.grad, + ) + + device = "cuda" + pt_model = ChainedTransposedConv().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + pt_x = torch.randn(1, 4, 8, requires_grad=True, device=device) + ort_x = copy.deepcopy(pt_x) + + pt_grads = run_step(pt_model, pt_x) + ort_grads = run_step(ort_model, ort_x) + + for pt_grad, ort_grad in zip(pt_grads, ort_grads): + if use_fp16: + assert torch.allclose(pt_grad, ort_grad, atol=1e-3, rtol=1e-3) + else: + assert torch.allclose(pt_grad, ort_grad) + + if conv_algo_search is not None: + del os.environ["ORTMODULE_CONV_ALGO_SEARCH"] + + +@pytest.mark.skipif( + os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm" +) +@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"]) +def test_conv_transpose_gradient_with_groups(conv_algo_search): + class TransposedConv3DWithGroups(nn.Module): + def __init__(self): + super().__init__() + # in_channels, out_channels, kernel_size, stride, padding + self.conv_transpose = nn.ConvTranspose3d( + in_channels=6, out_channels=4, kernel_size=3, stride=2, padding=1, groups=2 + ) + + def forward(self, x): + return self.conv_transpose(x) + + if conv_algo_search is not None: + os.environ["ORTMODULE_CONV_ALGO_SEARCH"] = conv_algo_search + + def run_step(model, x): + loss = model(x).sum() + loss.backward() + + return ( + x.grad, + model.conv_transpose.weight.grad, + model.conv_transpose.bias.grad, + ) + + device = "cuda" + pt_model = TransposedConv3DWithGroups().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + pt_x = torch.randn(1, 6, 8, 16, 16, requires_grad=True, device=device) + ort_x = copy.deepcopy(pt_x) + + pt_grads = run_step(pt_model, pt_x) + ort_grads = run_step(ort_model, ort_x) + + for pt_grad, ort_grad in zip(pt_grads, ort_grads): + assert torch.allclose(pt_grad, ort_grad) + + if conv_algo_search is not None: + del os.environ["ORTMODULE_CONV_ALGO_SEARCH"] + + +@pytest.mark.skipif( + os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm" +) +@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"]) +def test_conv_transpose_gradient_with_strides_padding_and_dilation(conv_algo_search): + class ConvTransposeComplexModel(nn.Module): + def __init__(self): + super().__init__() + self.conv_transpose = nn.ConvTranspose3d( + 16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2), dilation=(1, 2, 1) + ) + self.param = nn.Parameter(torch.randn(20, 33, 21, 50, 97)) + + def forward(self, x): + return self.conv_transpose(x) * self.param + + if conv_algo_search is not None: + os.environ["ORTMODULE_CONV_ALGO_SEARCH"] = conv_algo_search + + def run_step(model, x): + loss = model(x).sum() + loss.backward() + + return ( + x.grad, + model.conv_transpose.weight.grad, + model.conv_transpose.bias.grad, + ) + + device = "cuda" + pt_model = ConvTransposeComplexModel().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)).to(device) + + pt_x = torch.randn(20, 16, 10, 50, 100, requires_grad=True, device=device) + ort_x = copy.deepcopy(pt_x) + + pt_grads = run_step(pt_model, pt_x) + ort_grads = run_step(ort_model, ort_x) + + for pt_grad, ort_grad in zip(pt_grads, ort_grads): + assert torch.allclose(pt_grad, ort_grad, atol=1e-2, rtol=1e-2) + + if conv_algo_search is not None: + del os.environ["ORTMODULE_CONV_ALGO_SEARCH"] diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py index 56338ddbaf..d5c37b3e36 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py @@ -360,14 +360,18 @@ def test_add_get_property(property_value): if isinstance(property_value, float): property_value = float(np.float32(property_value)) - state["property"] = property_value - assert "property" in state - assert state["property"] == property_value + assert len(state.properties) == 0 + + state.properties["property"] = property_value + assert "property" in state.properties + assert state.properties["property"] == property_value + assert len(state.properties) == 1 CheckpointState.save_checkpoint(state, checkpoint_file_path) new_state = CheckpointState.load_checkpoint(checkpoint_file_path) - assert "property" in new_state - assert new_state["property"] == property_value + assert "property" in new_state.properties + assert new_state.properties["property"] == property_value + assert len(new_state.properties) == 1 def test_get_input_output_names(): @@ -563,3 +567,60 @@ def test_eval_step_with_ort_values(): fetches = model(inputs, labels) assert isinstance(fetches, OrtValue) assert fetches + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_get_and_set_parameter_values(device): + with tempfile.TemporaryDirectory() as temp_dir: + ( + checkpoint_file_path, + training_model_file_path, + eval_model_file_path, + _, + pt_model, + ) = _create_training_artifacts( + temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"] + ) + + state = CheckpointState.load_checkpoint(checkpoint_file_path) + + model = Module(training_model_file_path, state, eval_model_file_path, device=device) + + state_dict = pt_model.state_dict() + assert len(state_dict) == len(state.parameters) + for parameter_name, _ in state.parameters: + assert parameter_name in state_dict + + for name, pt_param in pt_model.named_parameters(): + ort_param = state.parameters[name] + assert ort_param.name == name + assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data) + if name in ["fc1.weight", "fc1.bias"]: + assert ort_param.requires_grad is False + assert ort_param.grad is None + else: + assert ort_param.requires_grad is True + assert np.allclose(ort_param.grad, np.zeros_like(ort_param.data, dtype=np.float32)) + + original_param = state.parameters["fc1.weight"].data + state.parameters["fc1.weight"].data = np.ones_like(state.parameters["fc1.weight"].data, dtype=np.float32) + updated_param = state.parameters["fc1.weight"].data + assert np.allclose(updated_param, np.ones_like(updated_param, dtype=np.float32)) + + model.train() + inputs = torch.randn(64, 784).numpy() + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + loss = model(inputs, labels) + assert loss is not None + for name, _ in pt_model.named_parameters(): + ort_param = state.parameters[name] + assert ort_param.name == name + if name in ["fc1.weight", "fc1.bias"]: + assert ort_param.requires_grad is False + assert ort_param.grad is None + else: + assert ort_param.requires_grad is True + assert ort_param.grad.any() + + state.parameters["fc1.weight"] = original_param + assert np.allclose(state.parameters["fc1.weight"].data, original_param) diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc index 4fa3844717..1369c9c698 100644 --- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc +++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc @@ -331,9 +331,12 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad) { #if defined(USE_CUDA) providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider()); #endif - auto model = std::make_unique(model_uri, &state, session_option, + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); + auto model = std::make_unique(model_identifier, &state, session_option, *env, providers); - auto optimizer = std::make_unique(optim_uri, &state, session_option, + auto optimizer = std::make_unique(model_identifier, &state, session_option, *env, providers); // Remove the temporary directory if it already exists. diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index ec0c7a1968..2170f7957e 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -76,9 +76,12 @@ void TestModuleExport(const std::vector>& pr std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(training_model_uri), + std::optional(onnxruntime::ToUTF8String(eval_model_uri)), + std::nullopt); auto model = std::make_unique( - ToUTF8String(training_model_uri), &state, onnxruntime::SessionOptions(), - *env, providers, ToUTF8String(eval_model_uri)); + model_identifier, &state, onnxruntime::SessionOptions(), + *env, providers); auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir"); if (Env::Default().FolderExists(test_dir)) { @@ -141,7 +144,9 @@ TEST(TrainingApiTest, ModuleParametersSize) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model = std::make_unique(ToUTF8String(model_uri), + auto model_identifiers = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, std::nullopt); + auto model = std::make_unique(model_identifiers, &state, session_option, *env, std::vector>()); size_t params_size = 0; @@ -164,7 +169,10 @@ TEST(TrainingApiTest, ModuleCopyBufferToParameters) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model = std::make_unique(ToUTF8String(model_uri), + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::nullopt); + auto model = std::make_unique(model_identifier, &state, session_option, *env, std::vector>()); int64_t params_size = static_cast(model->GetParametersSize()); @@ -202,7 +210,10 @@ TEST(TrainingApiTest, ModuleTrainStep) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model = std::make_unique(ToUTF8String(model_uri), + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::nullopt); + auto model = std::make_unique(model_identifier, &state, session_option, *env, std::vector>()); ASSERT_EQ(model->GetTrainingModelOutputCount(), 1); @@ -274,8 +285,12 @@ TEST(TrainingApiTest, OptimizerCreatedWithOptimizerCheckpointState) { ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); + std::shared_ptr model = std::make_shared( - ToUTF8String(model_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); // Load state dict from faked optimizer checkpoint state. @@ -285,7 +300,7 @@ TEST(TrainingApiTest, OptimizerCreatedWithOptimizerCheckpointState) { {"momentum0", "momentum1"}, external_optimizer_checkpoint_state)); std::shared_ptr optim = std::make_shared( - ToUTF8String(optim_uri), &new_state, session_option, *env, providers); + model_identifier, &new_state, session_option, *env, providers); ASSERT_TRUE(optim.get() != nullptr); } @@ -320,8 +335,12 @@ void TestLRSchduler(const std::basic_string& test_file_name, ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); + std::shared_ptr model = std::make_shared( - ToUTF8String(model_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); OrtValue input, target; @@ -351,7 +370,7 @@ void TestLRSchduler(const std::basic_string& test_file_name, } std::shared_ptr optim = std::make_shared( - ToUTF8String(optim_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); // KNOWN ISSUE: LinearLRScheduler by default use optim's states to calculate the first step's learning rate. @@ -445,11 +464,15 @@ TEST(TrainingApiTest, OptimStep) { providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider()); #endif ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); auto model = std::make_unique( - ToUTF8String(model_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); auto optim = std::make_unique( - ToUTF8String(optim_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); OrtValue input, target; diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index e864f3b863..e46952d87c 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "gtest/gtest.h" +#include "gmock/gmock.h" #include "onnxruntime_c_api.h" #include "onnxruntime_training_c_api.h" @@ -16,6 +17,7 @@ namespace onnxruntime::training::test { #define MODEL_FOLDER ORT_TSTR("testdata/training_api/") +#define ORT_FORMAT_MODEL_FOLDER ORT_TSTR("testdata/training_api/ort_format/") TEST(TrainingCApiTest, SaveCheckpoint) { auto model_uri = MODEL_FOLDER "training_model.onnx"; @@ -220,4 +222,202 @@ TEST(TrainingCApiTest, RegisterCustomOps) { ASSERT_TRUE(loss.front().IsTensor()); } +TEST(TrainingCApiTest, LoadModelsAndCreateSession) { + auto model_path = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + model_path); +} + +TEST(TrainingCApiTest, LoadModelsAndCreateSession_ORTFormat) { + auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort"; + auto eval_train_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort"; + auto optimizer_model_path = ORT_FORMAT_MODEL_FOLDER "optimizer_model.ort"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(ORT_FORMAT_MODEL_FOLDER "checkpoint"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_path, + eval_train_model_path, + optimizer_model_path); +} + +TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) { + auto model_path = MODEL_FOLDER "training_model.onnx"; + size_t model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len)); + std::vector train_model_data(model_data_len); + std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(train_model_data.data()), model_data_len); + ASSERT_TRUE(train_model_data.size() == model_data_len); + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_data); +} + +TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) { + auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort"; + auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort"; + auto optimizer_model_path = ORT_FORMAT_MODEL_FOLDER "optimizer_model.ort"; + size_t model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(train_model_path, model_data_len)); + std::vector train_model_data(model_data_len); + { + std::ifstream bytes_stream(train_model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(train_model_data.data()), model_data_len); + ASSERT_TRUE(train_model_data.size() == model_data_len); + } + + model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, model_data_len)); + std::vector eval_model_data(model_data_len); + { + std::ifstream bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(eval_model_data.data()), model_data_len); + ASSERT_TRUE(eval_model_data.size() == model_data_len); + } + + model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(optimizer_model_path, model_data_len)); + std::vector optimizer_model_data(model_data_len); + { + std::ifstream bytes_stream(optimizer_model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(optimizer_model_data.data()), model_data_len); + ASSERT_TRUE(optimizer_model_data.size() == model_data_len); + } + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(ORT_FORMAT_MODEL_FOLDER "checkpoint"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), + checkpoint_state, train_model_data, + eval_model_data, optimizer_model_data); +} + +TEST(TrainingCApiTest, LoadModelsFromBufferThrows) { + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + + try { + std::vector train_model_data; + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_data); + } catch (const std::exception& ex) { + ASSERT_THAT(ex.what(), + testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL.")); + } +} + +TEST(TrainingCApiTest, GetParameter) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); +} + +TEST(TrainingCApiTest, UpdateParameter) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); + + OrtValue* updated_param_value = std::make_unique().release(); + GenerateRandomInput(std::array{500, 784}, *updated_param_value); + Ort::Value updated_parameter{updated_param_value}; + checkpoint_state.UpdateParameter("fc1.weight", updated_parameter); + + Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight"); + gsl::span actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span not_expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); + + checkpoint_state.UpdateParameter("fc1.weight", parameter); + current_parameter = checkpoint_state.GetParameter("fc1.weight"); + actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + not_expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); +} + +#ifdef USE_CUDA +TEST(TrainingCApiTest, UpdateParameterDifferentDevices) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::SessionOptions session_options; + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, session_options, checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); + + OrtValue* updated_param_value = std::make_unique().release(); + GenerateRandomInput(std::array{500, 784}, *updated_param_value); + Ort::Value updated_parameter{updated_param_value}; + checkpoint_state.UpdateParameter("fc1.weight", updated_parameter); + + Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight"); + gsl::span actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span not_expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); + + checkpoint_state.UpdateParameter("fc1.weight", parameter); + current_parameter = checkpoint_state.GetParameter("fc1.weight"); + actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + not_expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); +} +#endif + } // namespace onnxruntime::training::test diff --git a/orttraining/orttraining/test/training_ops/cuda/conv_transpose_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/conv_transpose_grad_test.cc new file mode 100644 index 0000000000..18c5ff9437 --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/conv_transpose_grad_test.cc @@ -0,0 +1,360 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime::contrib::test { + +using namespace onnxruntime::test; + +#if USE_CUDA +namespace { + +struct ConvTransposeGradOpAttributes { + std::vector dilations; + int64_t group; + std::vector kernel_shape; + std::vector pads; + std::vector strides; +}; + +void TestConvTransposeGradOp(const ConvTransposeGradOpAttributes& attributes, + const std::vector>& inputs, + const std::vector>& input_shapes, + const std::vector>& outputs, + const std::vector>& output_shapes, + bool is_half = false) { + OpTester test("ConvTransposeGrad", 1, kMSDomain); + test.AddAttribute("group", attributes.group); + test.AddAttribute("kernel_shape", attributes.kernel_shape); + test.AddAttribute("pads", attributes.pads); + + if (!attributes.dilations.empty()) { + test.AddAttribute("dilations", attributes.dilations); + } + + if (!attributes.strides.empty()) { + test.AddAttribute("strides", attributes.strides); + } + + if (is_half) { + std::vector dY_half(inputs[0].size()); + ConvertFloatToMLFloat16(inputs[0].data(), dY_half.data(), static_cast(inputs[0].size())); + test.AddInput("dY", input_shapes[0], dY_half); + + std::vector X_half(inputs[1].size()); + ConvertFloatToMLFloat16(inputs[1].data(), X_half.data(), static_cast(inputs[1].size())); + test.AddInput("X", input_shapes[1], X_half); + + std::vector W_half(inputs[2].size()); + ConvertFloatToMLFloat16(inputs[2].data(), W_half.data(), static_cast(inputs[2].size())); + test.AddInput("W", input_shapes[2], W_half); + + std::vector dX_half(outputs[0].size()); + ConvertFloatToMLFloat16(outputs[0].data(), dX_half.data(), static_cast(outputs[0].size())); + test.AddOutput("dX", output_shapes[0], dX_half); + + std::vector dW_half(outputs[1].size()); + ConvertFloatToMLFloat16(outputs[1].data(), dW_half.data(), static_cast(outputs[1].size())); + test.AddOutput("dW", output_shapes[1], dW_half); + + if (outputs.size() >= 3) { + std::vector dB_half(outputs[2].size()); + ConvertFloatToMLFloat16(outputs[2].data(), dB_half.data(), static_cast(outputs[2].size())); + test.AddOutput("dB", output_shapes[2], dB_half); + } + } else { + test.AddInput("dY", input_shapes[0], inputs[0]); + test.AddInput("X", input_shapes[1], inputs[1]); + test.AddInput("W", input_shapes[2], inputs[2]); + + test.AddOutput("dX", output_shapes[0], outputs[0]); + test.AddOutput("dW", output_shapes[1], outputs[1]); + + if (outputs.size() >= 3) { + test.AddOutput("dB", output_shapes[2], outputs[2]); + } + } + + test.Run(); +} + +} // namespace + +TEST(ConvTransposeGradTest, ConvTranspose1DDefaultAttributes) { + ConvTransposeGradOpAttributes attrs = { + std::vector{1}, // dilations + 1, // group + std::vector{2}, // kernel_shape + std::vector{0, 0}, // pads + std::vector{1}, // strides + }; + + std::vector dY(12, 1.0f); + std::vector dY_shape = {1, 2, 6}; + std::vector X = {0.1868f, -0.1679f, 1.2677f, 2.1288f, -0.0331f, + 1.0454f, 0.7722f, 0.2963f, -0.8684f, -0.0547f}; + std::vector X_shape = {1, 2, 5}; + std::vector W = {0.0847f, -0.0066f, + 0.1212f, 0.2317f, + -0.4975f, 0.2762f, + -0.2644f, 0.3210f}; + std::vector W_shape = {2, 2, 2}; + std::vector dX = {0.4309f, 0.4309f, 0.4309f, 0.4309f, 0.4309f, + -0.1647f, -0.1647f, -0.1647f, -0.1647f, -0.1647f}; + std::vector dX_shape = X_shape; + std::vector dW = {3.3823f, 3.3823f, + 3.3823f, 3.3823f, + 1.1908f, 1.1908f, + 1.1908f, 1.1908f}; + std::vector dW_shape = W_shape; + std::vector dB = {6.f, 6.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose1DStrideAndPadding) { + ConvTransposeGradOpAttributes attrs = { + std::vector{1}, // dilations + 1, // group + std::vector{2}, // kernel_shape + std::vector{2, 2}, // pads + std::vector{2}, // strides + }; + + std::vector dY(12, 1.0f); + std::vector dY_shape = {1, 2, 6}; + std::vector X = {-0.0254f, -1.4303f, -0.1568f, 1.2318f, -0.8365f, + 2.0836f, -1.0181f, -0.7539f, 0.4484f, -0.5799f}; + std::vector X_shape = {1, 2, 5}; + std::vector W = {-0.1438f, 0.2386f, + -0.3085f, 0.1149f, + -0.1653f, -0.0707f, + -0.1479f, -0.0918f}; + std::vector W_shape = {2, 2, 2}; + std::vector dX = {0.0000f, -0.0988f, -0.0988f, -0.0988f, 0.0000f, + 0.0000f, -0.4757f, -0.4757f, -0.4757f, 0.0000f}; + std::vector dX_shape = X_shape; + std::vector dW = {-0.3553f, -0.3553f, + -0.3553f, -0.3553f, + -1.3236f, -1.3236f, + -1.3236f, -1.3236f}; + std::vector dW_shape = W_shape; + std::vector dB = {6.f, 6.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose1D) { + ConvTransposeGradOpAttributes attrs = { + std::vector{2}, // dilations + 2, // group + std::vector{3}, // kernel_shape + std::vector{2, 2}, // pads + std::vector{2}, // strides + }; + + std::vector dY(38, 1.0f); + std::vector dY_shape = {1, 2, 19}; + std::vector X = {0.2816f, 1.4660f, 0.1002f, -0.2460f, -0.1027f, 0.1228f, -0.8516f, -1.0246f, -0.6576f, -1.0280f, + 0.1093f, 0.1447f, 1.1279f, 0.1085f, -0.3438f, -0.6224f, -0.0902f, 2.2791f, -2.1910f, 1.9736f}; + std::vector X_shape = {1, 2, 10}; + std::vector W = {-0.1050f, -0.0622f, -0.3632f, + -0.3861f, -0.0134f, -0.0277f}; + std::vector W_shape = {2, 1, 3}; + std::vector dX = {-0.4254f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.1672f, + -0.0411f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.3995f}; + std::vector dX_shape = X_shape; + std::vector dW = {-2.2215f, -1.9400f, -0.9120f, + 2.3863f, 2.4956f, 0.5220f}; + std::vector dW_shape = W_shape; + std::vector dB = {19.f, 19.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose2DDefaultAttributes) { + ConvTransposeGradOpAttributes attrs = { + std::vector{1, 1}, // dilations + 1, // group + std::vector{3, 3}, // kernel_shape + std::vector{0, 0, 0, 0}, // pads + std::vector{1, 1}, // strides + }; + + std::vector dY(98, 1.0f); + std::vector dY_shape = {1, 2, 7, 7}; + std::vector X = {1.1371f, -0.1498f, -1.7541f, -0.7585f, 1.6009f, -0.7496f, 0.1535f, -0.2533f, -1.0811f, 0.9760f, + -0.2528f, 0.1820f, -1.7450f, 0.1632f, -0.3469f, 1.1150f, -2.6888f, -0.1632f, -0.3269f, 0.6904f, + 1.3036f, 0.7883f, 0.4459f, 0.1223f, 0.1576f, -0.8187f, 0.2281f, 1.5320f, 1.2643f, -0.5163f, + 1.0677f, -0.2141f, 1.2992f, -2.1865f, -0.6346f, 0.8938f, 0.8346f, -2.7397f, 0.9223f, 0.8166f, + 1.1736f, -1.3644f, 0.0316f, -1.2904f, 0.7062f, 0.2470f, 0.4559f, 0.8493f, 1.0519f, 0.9915f}; + std::vector X_shape = {1, 2, 5, 5}; + std::vector W = {0.0761f, 0.0270f, -0.1677f, 0.1803f, -0.0824f, -0.0285f, + 0.2098f, -0.0569f, -0.1514f, 0.0338f, -0.1962f, -0.2169f, + 0.0432f, -0.1977f, -0.0814f, -0.1866f, -0.1574f, -0.0198f, + 0.0097f, 0.0019f, -0.1204f, 0.2018f, -0.1750f, -0.0549f, + -0.0687f, -0.1269f, 0.1913f, 0.1331f, -0.0632f, 0.0821f, + 0.0127f, 0.1761f, -0.0883f, -0.1370f, 0.1472f, 0.0690f}; + std::vector W_shape = {2, 2, 3, 3}; + std::vector dX = {-0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, + -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, + -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, + 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, + 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f}; + std::vector dX_shape = X_shape; + std::vector dW = {-1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, + -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, + -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, + 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, + 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, + 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f}; + std::vector dW_shape = W_shape; + std::vector dB = {49.f, 49.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose2D) { + ConvTransposeGradOpAttributes attrs = { + std::vector{2, 2}, // dilations + 2, // group + std::vector{3, 3}, // kernel_shape + std::vector{2, 2, 2, 2}, // pads + std::vector{2, 2}, // strides + }; + + std::vector dY(162U, 1.0f); + std::vector dY_shape = {1, 2, 9, 9}; + std::vector X = {-1.0158f, 0.1709f, -0.1660f, 0.3881f, 0.4017f, 1.5497f, 1.1205f, 0.2553f, -0.4359f, -0.0467f, + 1.1374f, -0.0713f, 0.2248f, 0.8915f, -0.7239f, 0.1679f, -1.5604f, -0.8521f, 0.8966f, 3.3743f, + -0.5516f, 0.2516f, -0.4091f, -0.9868f, 0.3008f, 1.1066f, -0.7039f, -1.5273f, -0.3666f, 0.9392f, + 0.1264f, -1.6604f, -1.4810f, 0.6654f, -0.2007f, -1.0660f, -0.5420f, -0.7030f, 0.0411f, 2.1082f, + -0.7995f, 0.2422f, 1.2848f, -0.1747f, 1.7935f, -0.1123f, -0.6668f, -2.2383f, 1.5419f, -2.7614f}; + std::vector X_shape = {1, 2, 5, 5}; + std::vector W = {-0.2057f, -0.0411f, 0.0277f, 0.2221f, 0.1901f, 0.1435f, + -0.2249f, 0.3299f, -0.2203f, -0.1013f, -0.3326f, 0.1005f, + -0.0536f, 0.3067f, 0.3297f, 0.2728f, 0.1649f, -0.2548f}; + std::vector W_shape = {2, 1, 3, 3}; + std::vector dX = {0.4431f, 0.4403f, 0.4403f, 0.4403f, 0.5171f, 0.4297f, 0.2212f, 0.2212f, 0.2212f, 0.2704f, + 0.4297f, 0.2212f, 0.2212f, 0.2212f, 0.2704f, 0.4297f, 0.2212f, 0.2212f, 0.2212f, 0.2704f, + 0.3202f, 0.3366f, 0.3366f, 0.3366f, 0.1654f, 0.5465f, 0.7658f, 0.7658f, 0.7658f, 0.6908f, + 0.3144f, 0.4323f, 0.4323f, 0.4323f, 0.2569f, 0.3144f, 0.4323f, 0.4323f, 0.4323f, 0.2569f, + 0.3144f, 0.4323f, 0.4323f, 0.4323f, 0.2569f, 0.4043f, 0.2494f, 0.2494f, 0.2494f, -0.1808f}; + std::vector dX_shape = X_shape; + std::vector dW = {2.2293f, 4.5327f, 1.6281f, 3.0240f, 4.3115f, 1.0052f, + 3.8675f, 5.7067f, 2.7011f, -2.7512f, -4.6026f, -5.5423f, + -4.4098f, -5.1546f, -7.0335f, -0.2852f, -0.9177f, -5.5580f}; + std::vector dW_shape = W_shape; + std::vector dB = {81.f, 81.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose3D) { + ConvTransposeGradOpAttributes attrs = { + std::vector{2, 2, 2}, // dilations + 2, // group + std::vector{2, 2, 2}, // kernel_shape + std::vector{2, 2, 2, 2, 2, 2}, // pads + std::vector{2, 2, 2}, // strides + }; + + std::vector dY(250U, 1.0f); + std::vector dY_shape = {1, 2, 5, 5, 5}; + std::vector X = {-0.2396f, 0.4280f, -1.3505f, -0.4366f, -1.3296f, 0.3531f, 0.0645f, -1.5480f, + -1.7464f, -0.9160f, 1.5065f, -0.0788f, 0.0487f, 2.4641f, 0.3855f, 2.0499f, + 0.7068f, -0.8076f, -0.4442f, 0.1003f, -0.5056f, -0.1430f, -0.3744f, -0.2637f, + -1.1012f, 1.0213f, 0.0503f, 0.0147f, -0.3664f, 0.8834f, -1.1478f, -0.8221f, + -0.5649f, -0.4224f, -0.6779f, -0.9363f, 1.1972f, 0.2094f, 0.5676f, -0.2718f, + -0.1678f, -0.4178f, -0.4672f, 0.2777f, -0.7953f, -0.5603f, -2.8694f, 1.5743f, + -0.5057f, -0.2529f, 0.5894f, -0.3980f, -0.6719f, -0.3425f, 0.0821f, 0.8672f, + 0.7218f, 1.5519f, 1.6513f, -1.1956f, 0.8471f, 0.4295f, -1.3917f, -1.2202f, + 0.1054f, -2.2191f, -0.9546f, 1.1750f, -2.3637f, 1.6297f, -0.5796f, 0.3850f, + 0.9287f, -0.3492f, -0.7284f, 0.2987f, -0.7534f, 0.7747f, -1.3198f, -0.3633f, + 1.8635f, -0.3187f, 0.9032f, -0.6083f, -0.4236f, -0.1929f, -1.1715f, -0.5591f, + -1.8290f, -1.1503f, 0.1430f, 0.6048f, -0.3148f, 1.0638f, -0.2946f, -0.4990f, + -1.4443f, -0.7757f, -1.5374f, -0.4567f, -0.2998f, 0.0521f, 1.6293f, -0.6720f, + -0.0102f, -0.6598f, 0.5005f, 0.4203f, 1.3911f, 1.5988f, 0.3991f, 1.4931f, + 0.9741f, 0.3557f, 0.1088f, -1.1806f, 1.1115f, -1.3283f, 1.7235f, 0.4177f, + 0.7992f, -1.7248f, -0.5339f, -0.3153f, 0.1379f, 0.7493f, 0.3028f, -0.9473f}; + std::vector X_shape = {1, 2, 4, 4, 4}; + std::vector W = {-0.1093f, -0.0511f, 0.1132f, 0.3369f, -0.3531f, -0.1766f, 0.0628f, 0.2118f, + 0.3068f, 0.3217f, -0.2903f, -0.1633f, -0.3261f, -0.0990f, 0.2497f, -0.1553f}; + std::vector W_shape = {2, 1, 2, 2, 2}; + std::vector dX = {0.2118f, 0.2746f, 0.2746f, 0.0628f, 0.0352f, -0.2550f, -0.2550f, -0.2902f, + 0.0352f, -0.2550f, -0.2550f, -0.2902f, -0.1766f, -0.5297f, -0.5297f, -0.3531f, + 0.5487f, 0.7247f, 0.7247f, 0.1760f, 0.3210f, 0.0346f, 0.0346f, -0.2864f, + 0.3210f, 0.0346f, 0.0346f, -0.2864f, -0.2277f, -0.6901f, -0.6901f, -0.4624f, + 0.5487f, 0.7247f, 0.7247f, 0.1760f, 0.3210f, 0.0346f, 0.0346f, -0.2864f, + 0.3210f, 0.0346f, 0.0346f, -0.2864f, -0.2277f, -0.6901f, -0.6901f, -0.4624f, + 0.3369f, 0.4501f, 0.4501f, 0.1132f, 0.2858f, 0.2897f, 0.2897f, 0.0038f, + 0.2858f, 0.2897f, 0.2897f, 0.0038f, -0.0511f, -0.1604f, -0.1604f, -0.1093f, + -0.1553f, 0.0944f, 0.0944f, 0.2497f, -0.2542f, -0.3307f, -0.3307f, -0.0765f, + -0.2542f, -0.3307f, -0.3307f, -0.0765f, -0.0990f, -0.4251f, -0.4251f, -0.3261f, + -0.3185f, -0.3592f, -0.3592f, -0.0407f, -0.0958f, -0.1557f, -0.1557f, -0.0600f, + -0.0958f, -0.1557f, -0.1557f, -0.0600f, 0.2227f, 0.2035f, 0.2035f, -0.0193f, + -0.3185f, -0.3592f, -0.3592f, -0.0407f, -0.0958f, -0.1557f, -0.1557f, -0.0600f, + -0.0958f, -0.1557f, -0.1557f, -0.0600f, 0.2227f, 0.2035f, 0.2035f, -0.0193f, + -0.1633f, -0.4536f, -0.4536f, -0.2903f, 0.1584f, 0.1749f, 0.1749f, 0.0165f, + 0.1584f, 0.1749f, 0.1749f, 0.0165f, 0.3217f, 0.6285f, 0.6285f, 0.3068f}; + std::vector dX_shape = X_shape; + std::vector dW = {-2.3068f, -2.1096f, -0.4322f, 0.4820f, 1.5420f, -4.1569f, -4.9628f, -5.5716f, + 1.0492f, 1.6683f, -6.3262f, -3.2359f, 2.4532f, -2.3299f, -5.1917f, -9.2525f}; + std::vector dW_shape = W_shape; + std::vector dB = {125.f, 125.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} +#endif // USE_CUDA + +} // namespace onnxruntime::contrib::test diff --git a/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc new file mode 100644 index 0000000000..8fc13af881 --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/providers/compare_provider_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" + +namespace onnxruntime::test { + +#if defined(USE_CUDA) || defined(USE_ROCM) + +namespace { + +void AddResizeGradAttributes(OpTester& test, const std::string& coordinate_transformation_mode) { + test.AddAttribute("mode", "linear"); + test.AddAttribute("coordinate_transformation_mode", coordinate_transformation_mode); +} + +} // namespace + +TEST(ResizeGradTest, ResizeGradWithSizes) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(128, 1.0f); + std::vector dY_shape = {1, 2, 8, 8}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX(32, 4.0f); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithSizesHalf) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(128, 1.0f); + std::vector dY_half(dY.size()); + ConvertFloatToMLFloat16(dY.data(), dY_half.data(), static_cast(dY.size())); + std::vector dY_shape = {1, 2, 8, 8}; + + std::vector X(32, 1.0f); + std::vector X_half(X.size()); + ConvertFloatToMLFloat16(X.data(), X_half.data(), static_cast(X.size())); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX(32, 4.0f); + std::vector dX_half(dX.size()); + ConvertFloatToMLFloat16(dX.data(), dX_half.data(), static_cast(dX.size())); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY_half); + test.AddInput("X", X_shape, X_half); + + test.AddOutput("dX", dX_shape, dX_half); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithSizesAndAlignCorners) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "align_corners"); + + std::vector dY(128, 1.0f); + std::vector dY_shape = {1, 2, 8, 8}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({2.9388f, 3.9184f, 3.9184f, 2.9388f, 3.9184f, 5.2245f, 5.2245f, 3.9184f, + 3.9184f, 5.2245f, 5.2245f, 3.9184f, 2.9388f, 3.9184f, 3.9184f, 2.9388f, + 2.9388f, 3.9184f, 3.9184f, 2.9388f, 3.9184f, 5.2245f, 5.2245f, 3.9184f, + 3.9184f, 5.2245f, 5.2245f, 3.9184f, 2.9388f, 3.9184f, 3.9184f, 2.9388f}); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithScales) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(72, 1.0f); + std::vector dY_shape = {1, 2, 6, 6}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f, + 2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f}); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + test.AddInput("", {0}, {}); + test.AddInput("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f}); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithScalesHalf) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(72, 1.0f); + std::vector dY_half(dY.size()); + ConvertFloatToMLFloat16(dY.data(), dY_half.data(), static_cast(dY.size())); + std::vector dY_shape = {1, 2, 6, 6}; + + std::vector X(32, 1.0f); + std::vector X_half(X.size()); + ConvertFloatToMLFloat16(X.data(), X_half.data(), static_cast(X.size())); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f, + 2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f}); + std::vector dX_half(dX.size()); + ConvertFloatToMLFloat16(dX.data(), dX_half.data(), static_cast(dX.size())); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY_half); + test.AddInput("X", X_shape, X_half); + test.AddInput("", {0}, {}); + test.AddInput("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f}); + + test.AddOutput("dX", dX_shape, dX_half); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithScalesAndAlignCorners) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "align_corners"); + + std::vector dY(72, 1.0f); + std::vector dY_shape = {1, 2, 6, 6}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({1.9600f, 2.2400f, 2.2400f, 1.9600f, 2.2400f, 2.5600f, 2.5600f, 2.2400f, + 2.2400f, 2.5600f, 2.5600f, 2.2400f, 1.9600f, 2.2400f, 2.2400f, 1.9600f, + 1.9600f, 2.2400f, 2.2400f, 1.9600f, 2.2400f, 2.5600f, 2.5600f, 2.2400f, + 2.2400f, 2.5600f, 2.5600f, 2.2400f, 1.9600f, 2.2400f, 2.2400f, 1.9600f}); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + test.AddInput("", {0}, {}); + test.AddInput("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f}); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +#endif // defined(USE_CUDA) || defined(USE_ROCM) + +} // namespace onnxruntime::test diff --git a/orttraining/orttraining/training_api/checkpoint_property.h b/orttraining/orttraining/training_api/checkpoint_property.h index d7b1e295df..3c38c99b31 100644 --- a/orttraining/orttraining/training_api/checkpoint_property.h +++ b/orttraining/orttraining/training_api/checkpoint_property.h @@ -22,10 +22,12 @@ struct PropertyBag { PropertyBag() = default; void AddProperty(const std::string& name, const PropertyDataType& val) { - ORT_ENFORCE(named_properties_.find(name) == named_properties_.end(), - "Duplicated property named ", name); - - named_properties_.insert({name, val}); + auto it = named_properties_.find(name); + if (it == named_properties_.end()) { + named_properties_.insert({name, val}); + } else { + it->second = val; + } } template diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index b3042c449a..0e8544a763 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -190,7 +190,29 @@ struct OrtTrainingApi { ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, - _Outptr_ OrtTrainingSession** out); + _Outptr_result_maybenull_ OrtTrainingSession** out); + + /** \brief Create a training session that can be used to begin or resume training. + * This api provides a way to load all the training artifacts from buffers instead of files. + * + * \param[in] env Environment to be used for the training session. + * \param[in] options Session options that the user can customize for this training session. + * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. + * \param[in] train_model_data Buffer containing the model data to be used to perform training + * \param[in] train_data_length Length of the buffer containing train_model_data + * \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation + * \param[in] eval_data_length Length of the buffer containing eval_model_data + * \param[in] optim_model_data Buffer containing the model data to be used to perform weight update + * \param[in] optim_data_length Length of the buffer containing optim_model_data + * \param[out] out Created training session. + * + */ + ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out); /// @} @@ -586,14 +608,14 @@ struct OrtTrainingApi { /// \name Accessing The Training Session State /// @{ - /** \brief Adds the given property to the checkpoint state. + /** \brief Adds or updates the given property to/in the checkpoint state. * * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint - * state by the user if they desire by calling this function with the appropriate property name and - * value. The given property name must be unique to be able to successfully add the property. + * state by the user by calling this function with the corresponding property name and value. + * The given property name must be unique to be able to successfully add the property. * * \param[in] checkpoint_state The checkpoint state which should hold the property. - * \param[in] property_name Unique name of the property being added. + * \param[in] property_name Name of the property being added or updated. * \param[in] property_type Type of the property associated with the given name. * \param[in] property_value Property value associated with the given name. * @@ -610,7 +632,7 @@ struct OrtTrainingApi { * exist in the checkpoint state to be able to retrieve it successfully. * * \param[in] checkpoint_state The checkpoint state that is currently holding the property. - * \param[in] property_name Unique name of the property being retrieved. + * \param[in] property_name Name of the property being retrieved. * \param[in] allocator Allocator used to allocate the memory for the property_value. * \param[out] property_type Type of the property associated with the given name. * \param[out] property_value Property value associated with the given name. @@ -647,6 +669,57 @@ struct OrtTrainingApi { ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); + /** \brief Retrieves the type and shape information of the parameter associated with the given parameter name. + * + * This function retrieves the type and shape of the parameter associated with the given parameter name. + * The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being retrieved. + * \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); + + /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + * + * This function updates a model parameter in the checkpoint state with the given parameter data. + * The training session must be already created with the checkpoint state that contains the parameter + * being updated. The given parameter is copied over to the registered device for the training session. + * The parameter must exist in the checkpoint state to be able to update it successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being updated. + * \param[in] parameter The parameter data that should replace the existing parameter data. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter); + + /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + * + * This function retrieves the model parameter data from the checkpoint state for the given parameter name. + * The parameter is copied over and returned as an OrtValue. The training session must be already created + * with the checkpoint state that contains the parameter being retrieved. + * The parameter must exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being retrieved. + * \param[in] allocator Allocator used to allocate the memory for the parameter. + * \param[out] parameter The parameter data that is retrieved from the checkpoint state. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); + /// @} }; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index 5bfdfcc74e..218bef5242 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -112,13 +112,13 @@ class CheckpointState : public detail::Base { const std::basic_string& path_to_checkpoint, const bool include_optimizer_state = false); - /** \brief Adds the given property to the checkpoint state. + /** \brief Adds or updates the given property to/in the checkpoint state. * * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint - * state by the user if they desire by calling this function with the appropriate property name and - * value. The given property name must be unique to be able to successfully add the property. + * state by the user by calling this function with the corresponding property name and value. + * The given property name must be unique to be able to successfully add the property. * - * \param[in] property_name Unique name of the property being added. + * \param[in] property_name Name of the property being added or updated. * \param[in] property_value Property value associated with the given name. * */ @@ -129,12 +129,38 @@ class CheckpointState : public detail::Base { * Gets the property value from an existing entry in the checkpoint state. The property must * exist in the checkpoint state to be able to retrieve it successfully. * - * \param[in] property_name Unique name of the property being retrieved. + * \param[in] property_name Name of the property being retrieved. * \return Property value associated with the given property name. * */ Property GetProperty(const std::string& property_name); + /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + * + * This function updates a model parameter in the checkpoint state with the given parameter data. + * The training session must be already created with the checkpoint state that contains the parameter + * being updated. The given parameter is copied over to the registered device for the training session. + * The parameter must exist in the checkpoint state to be able to update it successfully. + * + * \param[in] parameter_name Name of the parameter being updated. + * \param[in] parameter The parameter data that should replace the existing parameter data. + * + */ + void UpdateParameter(const std::string& parameter_name, const Value& parameter); + + /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + * + * This function retrieves the model parameter data from the checkpoint state for the given parameter name. + * The parameter is copied over to the provided OrtValue. The training session must be already created + * with the checkpoint state that contains the parameter being retrieved. + * The parameter must exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] parameter_name Name of the parameter being retrieved. + * \return The parameter data that is retrieved from the checkpoint state. + * + */ + Value GetParameter(const std::string& parameter_name); + /// @} }; @@ -176,6 +202,20 @@ class TrainingSession : public detail::Base { const std::optional>& eval_model_path = std::nullopt, const std::optional>& optimizer_model_path = std::nullopt); + /** \brief Create a training session that can be used to begin or resume training. + * This constructor allows the users to load the models from buffers instead of files. + * + * \param[in] env Env to be used for the training session. + * \param[in] session_options SessionOptions that the user can customize for this training session. + * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. + * \param[in] train_model_data Buffer containing training model data. + * \param[in] eval_model_data Buffer containing evaluation model data. + * \param[in] optim_model_data Buffer containing optimizer model (used for performing weight/parameter update). + * + */ + TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state, + const std::vector& train_model_data, const std::vector& eval_model_data = {}, + const std::vector& optim_model_data = {}); /// @} /// \name Implementing The Training Loop diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index 393e5b01f7..a5efa3c0e4 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -24,6 +24,23 @@ inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& se ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_)); } +inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options, + CheckpointState& checkpoint_state, + const std::vector& train_model_data, + const std::vector& eval_model_data, + const std::vector& optim_model_data) { + ThrowOnError(GetTrainingApi().CreateTrainingSessionFromBuffer( + env, session_options, checkpoint_state, + train_model_data.data(), train_model_data.size(), + eval_model_data.data(), eval_model_data.size(), + optim_model_data.data(), optim_model_data.size(), + &p_)); + + ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_)); + + ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_)); +} + inline std::vector TrainingSession::TrainStep(const std::vector& input_values) { std::vector output_values; output_values.reserve(training_model_output_count_); @@ -262,4 +279,16 @@ inline Property CheckpointState::GetProperty(const std::string& property_name) { return property; } +inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) { + ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter)); +} + +inline Value CheckpointState::GetParameter(const std::string& parameter_name) { + AllocatorWithDefaultOptions allocator; + OrtValue* parameter; + ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, ¶meter)); + + return Value{parameter}; +} + } // namespace Ort diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 29300bbb7e..cf49a01517 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -12,7 +12,6 @@ #include "core/graph/graph_utils.h" #include "orttraining/training_api/checkpoint.h" -#include "orttraining/training_api/utils.h" using namespace onnxruntime; @@ -120,6 +119,61 @@ Status TransformModelInputsForInference(Graph& inference_graph, #endif } // namespace +Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const { + ORT_ENFORCE(data.IsAllocated(), "Given parameter data is not allocated. Cannot copy the checkpoint parameter to it."); + ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type."); + ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), + "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), + ", Got: ", data.Get().Shape().ToString()); +#ifdef ENABLE_STRIDED_TENSORS + auto data_strides = data.Get().Strides(); + auto param_strides = data_.Get().Strides(); + ORT_ENFORCE(data_strides.size() == param_strides.size(), + "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(), + ", Got: ", data_strides.size()); + ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()), + "Parameter data stride value mismatch."); +#endif + ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), + "Parameter data type mismatch. Expected: ", data_.Get().DataType(), + ", Got: ", data.Get().DataType()); + ORT_ENFORCE(data_transfer_manager != nullptr, + "Data transfer manager must be provided to copy data to the parameter. " + "Please create the TrainingSession before trying to update the parameter."); + + ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data_.Get(), *data.GetMutable())); + + return Status::OK(); +} + +Status Parameter::CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data) { + ORT_ENFORCE(data_.IsAllocated(), + "The checkpoint parameter is not allocated. Cannot copy the given parameter data to it."); + ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type."); + ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), + "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), + ", Got: ", data.Get().Shape().ToString()); +#ifdef ENABLE_STRIDED_TENSORS + auto data_strides = data.Get().Strides(); + auto param_strides = data_.Get().Strides(); + ORT_ENFORCE(data_strides.size() == param_strides.size(), + "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(), + ", Got: ", data_strides.size()); + ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()), + "Parameter data stride value mismatch."); +#endif + ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), + "Parameter data type mismatch. Expected: ", data_.Get().DataType(), + ", Got: ", data.Get().DataType()); + ORT_ENFORCE(data_transfer_manager != nullptr, + "Data transfer manager must be provided to copy data to the parameter. " + "Please create the TrainingSession before trying to update the parameter."); + + ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data.Get(), *data_.GetMutable())); + + return Status::OK(); +} + Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) { // assert param is allocated ORT_ENFORCE(data_.IsAllocated(), "Parameter data should be allocated before allocating gradient."); @@ -150,12 +204,11 @@ Status Parameter::ResetGrad() { return Status::OK(); } -Module::Module(const std::string& train_model_path_or_bytes, +Module::Module(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, const std::vector>& providers, - const std::optional& eval_model_path_or_bytes, [[maybe_unused]] gsl::span op_domains) : state_{state} { // Enforce weight prepacking is disabled @@ -176,7 +229,12 @@ Module::Module(const std::string& train_model_path_or_bytes, } #endif - ORT_THROW_IF_ERROR(train_sess_->Load(train_model_path_or_bytes)); + // Load the training model + ORT_THROW_IF_ERROR(std::holds_alternative(model_identifiers.train_model) + ? train_sess_->Load(std::get(model_identifiers.train_model)) + : train_sess_->Load(std::get>(model_identifiers.train_model).data(), + static_cast(std::get>(model_identifiers.train_model).size()))); + for (const auto& provider : providers) { ORT_THROW_IF_ERROR(train_sess_->RegisterExecutionProvider(provider)); } @@ -239,7 +297,6 @@ Module::Module(const std::string& train_model_path_or_bytes, // Copy ortvalue buffer from CPU to target_device for this "param_name" (based on graph partitioning) // Only copies data if the target device is not the same as the current device the buffer is placed on - OrtValue& param_data = params_iter->second->Data(); ORT_ENFORCE(param_data.IsTensor()); const Tensor& param_data_tensor = param_data.Get(); @@ -278,48 +335,62 @@ Module::Module(const std::string& train_model_path_or_bytes, } } - if (eval_model_path_or_bytes.has_value()) { + if (model_identifiers.IsEvalModelAvailable()) { eval_sess_ = std::make_unique(session_options, env); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) if (!op_domains.empty()) { ORT_THROW_IF_ERROR(eval_sess_->AddCustomOpDomains(op_domains)); } #endif - - ORT_THROW_IF_ERROR(eval_sess_->Load(eval_model_path_or_bytes.value())); - for (const auto& provider : providers) { - ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider)); - } - ORT_THROW_IF_ERROR(eval_sess_->Initialize()); - utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_); - - // Eval model validation - // We are making certain assumptions: Like the order in which parameters occur will be same between train and eval - // graphs, and all the weights present in both graphs match. - // TODO: Add the checks instead of making assumptions?? - InlinedVector eval_user_input_names, eval_param_input_names; - for (const auto& input_name : eval_input_names_) { - if (state_->module_checkpoint_state.named_parameters.find(input_name) != - state_->module_checkpoint_state.named_parameters.end()) { - // it is a parameter - eval_param_input_names.emplace_back(input_name); - continue; - } else { - // It is user input. We handle user inputs separately in the eval - // because the eval graph might have different user inputs. - // Eg if loss is not a part of the eval graph, it won't have - // certain inputs like targets - eval_user_input_names.emplace_back(input_name); - } + if (std::holds_alternative>(model_identifiers.eval_model)) { + ORT_THROW_IF_ERROR(eval_sess_->Load(std::get>(model_identifiers.eval_model).value())); + } else { + auto model_data = std::get>(model_identifiers.eval_model); + ORT_THROW_IF_ERROR(eval_sess_->Load(model_data.data(), static_cast(model_data.size()))); } - eval_input_names_ = eval_user_input_names; - eval_user_input_count_ = eval_user_input_names.size(); - eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end()); + } else { + return; + } - // Keep a copy of the eval model path to be able to later export the model for inferencing. - // The inference model will be reconstructed from the eval model. - eval_model_path_ = eval_model_path_or_bytes.value(); + for (const auto& provider : providers) { + ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider)); } + ORT_THROW_IF_ERROR(eval_sess_->Initialize()); + utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_); + + // Eval model validation + // We are making certain assumptions: Like the order in which parameters occur will be same between train and eval + // graphs, and all the weights present in both graphs match. + // TODO(askhade): Add the checks instead of making assumptions?? + InlinedVector eval_user_input_names, eval_param_input_names; + for (const auto& input_name : eval_input_names_) { + if (state_->module_checkpoint_state.named_parameters.find(input_name) != + state_->module_checkpoint_state.named_parameters.end()) { + // it is a parameter + eval_param_input_names.emplace_back(input_name); + continue; + } else { + // It is user input. We handle user inputs separately in the eval + // because the eval graph might have different user inputs. + // Eg if loss is not a part of the eval graph, it won't have + // certain inputs like targets + eval_user_input_names.emplace_back(input_name); + } + } + eval_input_names_ = eval_user_input_names; + eval_user_input_count_ = eval_user_input_names.size(); + eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end()); + + // Keep a copy of the eval model path to be able to later export the model for inferencing. + // The inference model will be reconstructed from the eval model. + // TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer. + if (std::holds_alternative>(model_identifiers.eval_model)) { + eval_model_path_ = std::get>(model_identifiers.eval_model); + } +} + +Module::~Module() { + state_->module_checkpoint_state.train_session_data_transfer_mgr = nullptr; } size_t Module::GetTrainingModelOutputCount() const noexcept { @@ -486,14 +557,14 @@ Status Module::EvalStep(const std::vector& inputs, std::vector graph_output_names) const { - ORT_RETURN_IF(!eval_sess_ || eval_model_path_.empty(), + ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(), "Eval model was not provided. Cannot export a model for inferencing."); ONNX_NAMESPACE::ModelProto eval_model; - ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_), eval_model)); + ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model)); // Clone the eval mode into an inference onnxruntime::Model. std::shared_ptr inference_model; diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index 9013ab22c1..f323e6be72 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -3,7 +3,9 @@ #pragma once +#include #include "core/session/inference_session.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime { namespace training { @@ -19,6 +21,8 @@ struct Parameter { // Return the mutable data. OrtValue& Data() { return data_; } + Status CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const; + Status CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data); const std::string& Name() const { return name_; } // Returns whether this parameter is trainable or not. @@ -32,7 +36,6 @@ struct Parameter { // Reset and release the gradient buffer of this Parameter greedily. Status ResetGrad(); - protected: Status SetGrad(const std::string& gradient_name, const OrtValue& param_grad); private: @@ -73,14 +76,16 @@ struct Module { public: // Initialize a module from an ORT inference session with loaded // training ONNX model and load parameters - Module(const std::string& train_model_path_or_bytes, + // The model and checkpoint state can be provided as a file path or a byte array + Module(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, const std::vector>& providers, - const std::optional& eval_model_path_or_bytes = std::nullopt, gsl::span op_domains = gsl::span()); + ~Module(); + // Return the trainable/nontrainable parameters std::vector> Parameters() const; @@ -159,7 +164,7 @@ struct Module { CheckpointState* state_; // Non owning pointer to the state. bool accumulate_gradient_ = false; - std::string eval_model_path_; + std::optional eval_model_path_; size_t train_user_input_count_{0U}; size_t eval_user_input_count_{0U}; }; diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index b84009e7f3..38a9aad964 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -13,6 +13,8 @@ #include "orttraining/training_api/ort_training_apis.h" #include "orttraining/training_api/training_session.h" +using namespace onnxruntime::training::api; + namespace { std::vector> CreateProviders( @@ -26,44 +28,85 @@ std::vector> CreateProviders( return execution_providers; } +static OrtStatus* CreateSessionAndLoadModel(_In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, + _Inout_ OrtCheckpointState* checkpoint_state, + const ModelIdentifiers& model_identifiers, + std::unique_ptr& train_sess) { + auto chkpt_state = reinterpret_cast(checkpoint_state); + + using ProvidersType = std::vector>; + train_sess = std::make_unique(env->GetEnvironment(), + options == nullptr ? onnxruntime::SessionOptions() : options->value, + options == nullptr + ? ProvidersType() + : CreateProviders(options->provider_factories), + chkpt_state, + model_identifiers, + options == nullptr + ? gsl::span() + : options->custom_op_domains_); + + return nullptr; +} + } // namespace ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, - _In_ const ORTCHAR_T* optimizer_model_path, _Outptr_ OrtTrainingSession** out) { + _In_ const ORTCHAR_T* optimizer_model_path, _Outptr_result_maybenull_ OrtTrainingSession** out) { API_IMPL_BEGIN if (options != nullptr && options->value.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseEnvAllocators, "0") == "1") { return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Use Env Allocators is not supported for on device training."); } std::unique_ptr train_sess; - auto chkpt_state = reinterpret_cast(checkpoint_state); OrtStatus* status = nullptr; *out = nullptr; - ORT_TRY { - using ProvidersType = std::vector>; - train_sess = std::make_unique( - env->GetEnvironment(), - options == nullptr ? onnxruntime::SessionOptions() : options->value, - options == nullptr ? ProvidersType() : CreateProviders(options->provider_factories), - chkpt_state, - onnxruntime::training::api::ModelIdentifiers( - onnxruntime::ToUTF8String(train_model_path), - eval_model_path ? std::optional(onnxruntime::ToUTF8String(eval_model_path)) - : std::nullopt, - optimizer_model_path ? std::optional(onnxruntime::ToUTF8String(optimizer_model_path)) - : std::nullopt), - options == nullptr ? gsl::span() : options->custom_op_domains_); - - *out = reinterpret_cast(train_sess.release()); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = OrtApis::CreateStatus(ORT_FAIL, e.what()); - }); - } + ORT_ENFORCE(train_model_path != nullptr, + "Train model path is required to create TrainingSession, it cannot be empty."); + + auto model_identifiers = onnxruntime::training::api::ModelIdentifiers( + onnxruntime::ToUTF8String(train_model_path), + eval_model_path ? std::optional(onnxruntime::ToUTF8String(eval_model_path)) + : std::nullopt, + optimizer_model_path ? std::optional(onnxruntime::ToUTF8String(optimizer_model_path)) + : std::nullopt); + + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(env, options, checkpoint_state, model_identifiers, train_sess)); + *out = reinterpret_cast(train_sess.release()); + + return status; + API_IMPL_END +} +ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out) { + API_IMPL_BEGIN + std::unique_ptr train_sess; + OrtStatus* status = nullptr; + *out = nullptr; + + ORT_ENFORCE(train_model_data != nullptr && train_data_length != 0, + "Training Session Creation failed. Train model data cannot be NULL."); + + auto model_identifiers = ModelIdentifiers(gsl::make_span(reinterpret_cast(train_model_data), + train_data_length), + eval_data_length == 0 || eval_model_data == nullptr + ? gsl::span() + : gsl::make_span(reinterpret_cast(eval_model_data), + eval_data_length), + optim_data_length == 0 || optim_model_data == nullptr + ? gsl::span() + : gsl::make_span(reinterpret_cast(optim_model_data), + optim_data_length)); + + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(env, options, checkpoint_state, model_identifiers, train_sess)); + *out = reinterpret_cast(train_sess.release()); return status; API_IMPL_END } @@ -290,6 +333,10 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void* _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) { API_IMPL_BEGIN + if (checkpoint_buffer == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint buffer. Actual: nullptr."); + } + *checkpoint_state = nullptr; auto chkpt_state = std::make_unique(); const auto* checkpoint_bytes = reinterpret_cast(checkpoint_buffer); @@ -516,6 +563,76 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetProperty, _In_ const OrtCheckpointState* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape) { + API_IMPL_BEGIN + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + + return OrtApis::GetTensorTypeAndShape(&it->second->Data(), parameter_type_and_shape); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter) { + API_IMPL_BEGIN + if (parameter == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr."); + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom( + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter) { + API_IMPL_BEGIN + + if (parameter == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr."); + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + + if (!it->second->Data().IsTensor()) { + return OrtApis::CreateStatus(ORT_FAIL, "Expected a tensor type for the parameter. Found a non-tensor type."); + } + const auto& parameter_tensor = it->second->Data().Get(); + ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue( + allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(), + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter)); + + auto status = it->second->CopyTo( + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter); + if (!status.IsOK()) { + OrtApis::ReleaseValue(*parameter); + return onnxruntime::ToOrtStatus(status); + } + + return nullptr; + API_IMPL_END +} + static constexpr OrtTrainingApi ort_training_api = { // NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially // released, it is OK to change the order here, however a corresponding matching change should also be done in the @@ -523,6 +640,7 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::LoadCheckpoint, &OrtTrainingApis::SaveCheckpoint, &OrtTrainingApis::CreateTrainingSession, + &OrtTrainingApis::CreateTrainingSessionFromBuffer, &OrtTrainingApis::TrainingSessionGetTrainingModelOutputCount, &OrtTrainingApis::TrainingSessionGetEvalModelOutputCount, &OrtTrainingApis::TrainingSessionGetTrainingModelOutputName, @@ -548,7 +666,10 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::TrainingSessionGetEvalModelInputName, &OrtTrainingApis::AddProperty, &OrtTrainingApis::GetProperty, - &OrtTrainingApis::LoadCheckpointFromBuffer}; + &OrtTrainingApis::LoadCheckpointFromBuffer, + &OrtTrainingApis::GetParameterTypeAndShape, + &OrtTrainingApis::UpdateParameter, + &OrtTrainingApis::GetParameter}; ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) { // No constraints on the API version yet. diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc index a6b82f1d50..7f583ce8f6 100644 --- a/orttraining/orttraining/training_api/optimizer.cc +++ b/orttraining/orttraining/training_api/optimizer.cc @@ -61,19 +61,10 @@ Status GraphInputsAreExpected(gsl::span actual_graph_inputs, } // namespace std::unique_ptr OptimizerAlorithmFactory::CreateInstance( - const std::string& optim_path, int32_t& group_count) { + std::shared_ptr model, int32_t& group_count) { std::map, int32_t> opt_type_to_freq_map; #if !defined(ORT_MINIMAL_BUILD) - if (const auto optim_path_str = ToPathString(optim_path); - fbs::utils::IsOrtFormatModel(optim_path_str)) { - // TODO (baijumeswani): Figure out the best way to extract the optimizer type - // from an ort format model. - opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; - } else { - std::shared_ptr model; - ORT_ENFORCE(Model::Load(optim_path_str, model, nullptr, - logging::LoggingManager::DefaultLogger()) - .IsOK()); + if (model != nullptr) { Graph& graph = model->MainGraph(); for (auto& node : graph.Nodes()) { if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) { @@ -85,33 +76,71 @@ std::unique_ptr OptimizerAlorithmFactory::CreateInstance opt_type_to_freq_map[domain_type_pair] += 1; } } - } + } else { #else - // TODO (baijumeswani): Figure out the best way to extract the optimizer type - // from the model (either onnx model or ort format model) or from the checkpoint. - // For now, assume that the optimizer type is AdamWOptimizer in a minimal build. - ORT_UNUSED_PARAMETER(optim_path); - - opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; + ORT_UNUSED_PARAMETER(model); +#endif + // TODO(baijumeswani): Figure out the best way to extract the optimizer type + // from the model (either onnx model or ort format model) or from the checkpoint. + // For now, assume that the optimizer type is AdamWOptimizer when using ort format models. + opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; +#if !defined(ORT_MINIMAL_BUILD) + } #endif ORT_ENFORCE(opt_type_to_freq_map.size() == 1U, "Only support one type of optimizer algorithm, but got: " + std::to_string(opt_type_to_freq_map.size())); auto opt_it = opt_type_to_freq_map.begin(); + auto& op_type = opt_it->first.second; group_count = opt_it->second; - auto& domain = opt_it->first.first; - auto& type = opt_it->first.second; + ORT_ENFORCE(group_count == 1, "Group count can only be 1, but got: " + std::to_string(group_count)); // TODO: to support multiple groups, need to create a mapping between each group to its parameter list. - if (domain == kMSDomain && type == "AdamWOptimizer") { + if (op_type == "AdamWOptimizer") { return std::make_unique(); - } else if (domain == kMSDomain && type == "SGDOptimizerV2") { + } else if (op_type == "SGDOptimizerV2") { return std::make_unique(); } else { ORT_NOT_IMPLEMENTED("Not implemented for optimizer algo: " + opt_it->first.second); } } +std::unique_ptr OptimizerAlorithmFactory::CreateInstance( + const PathString& optim_path, int32_t& group_count) { + std::shared_ptr model = nullptr; +#if !defined(ORT_MINIMAL_BUILD) + if (!fbs::utils::IsOrtFormatModel(optim_path)) { + ORT_ENFORCE(Model::Load(optim_path, model, nullptr, + logging::LoggingManager::DefaultLogger()) + .IsOK()); + } +#else + ORT_UNUSED_PARAMETER(optim_path); +#endif + return CreateInstance(model, group_count); +} + +std::unique_ptr OptimizerAlorithmFactory::CreateInstance( + const uint8_t* optim_model_data, size_t optim_model_data_len, int32_t& group_count) { + std::shared_ptr model = nullptr; +#if !defined(ORT_MINIMAL_BUILD) + if (!fbs::utils::IsOrtFormatModelBytes(optim_model_data, static_cast(optim_model_data_len))) { + ONNX_NAMESPACE::ModelProto model_proto; + ORT_ENFORCE(model_proto.ParseFromArray(optim_model_data, static_cast(optim_model_data_len)) == true, + "Failed to load model because protobuf parsing failed."); + + ORT_ENFORCE(Model::Load(std::move(model_proto), model, nullptr, + logging::LoggingManager::DefaultLogger(), ModelOptions(true, true)) + .IsOK()); + } +#else + ORT_UNUSED_PARAMETER(optim_model_data); + ORT_UNUSED_PARAMETER(optim_model_data_len); +#endif + + return CreateInstance(model, group_count); +} + Status Optimizer::GenerateMomentumNamedStates(OptimizerCheckpointState& optimizer_checkpoint_states) { auto group_optimizer_state_it = optimizer_checkpoint_states.group_named_optimizer_states.find(GROUP_ZERO_NAME); @@ -200,14 +229,14 @@ Status Optimizer::ConstructInputs() { return Status::OK(); } // namespace api -Optimizer::Optimizer(const std::string& optim_path_or_bytes, +Optimizer::Optimizer(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, const std::vector>& providers, gsl::span op_domains) : optim_sess_(std::make_unique(session_options, env)), state_(state) { - Initialize(optim_path_or_bytes, providers, op_domains); + Initialize(model_identifiers, providers, op_domains); ORT_ENFORCE(state != nullptr, "Checkpoint state cannot be null."); auto g_it = state_->optimizer_checkpoint_state.group_named_optimizer_states.find(GROUP_ZERO_NAME); @@ -223,7 +252,7 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes, } } -void Optimizer::Initialize(const std::string& optim_path_or_bytes, +void Optimizer::Initialize(const ModelIdentifiers& model_identifiers, const std::vector>& providers, [[maybe_unused]] gsl::span op_domains) { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) @@ -236,7 +265,22 @@ void Optimizer::Initialize(const std::string& optim_path_or_bytes, ORT_THROW_IF_ERROR(optim_sess_->RegisterExecutionProvider(execution_provider)); } - ORT_THROW_IF_ERROR(optim_sess_->Load(optim_path_or_bytes)); + ORT_ENFORCE(model_identifiers.IsOptimizerModelAvailable(), "Optimizer model is not available."); + + if (std::holds_alternative>(model_identifiers.optim_model)) { + auto optimizer_model = std::get>(model_identifiers.optim_model); + // The above call to IsOptimizerModelAvailable() ensures that optimizer_model is not nullopt + ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.value())); + optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(ToWideString(optimizer_model.value()), group_count_); + } else { + auto optimizer_model = std::get>(model_identifiers.optim_model); + ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.data(), + static_cast(optimizer_model.size()))); + optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optimizer_model.data(), + optimizer_model.size(), + group_count_); + } + ORT_THROW_IF_ERROR(optim_sess_->Initialize()); // Make sure that the checkpoint state can copy tensors @@ -244,10 +288,6 @@ void Optimizer::Initialize(const std::string& optim_path_or_bytes, utils::GetGraphInputOutputNames(optim_sess_, input_names_, output_names_); - optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optim_path_or_bytes, group_count_); - ORT_ENFORCE(group_count_ == 1, "Group count can only be 1, but got: " + std::to_string(group_count_)); - ORT_ENFORCE(optimizer_algo_ptr_, "optimizer_algo_ptr_ should not be nullptr."); - InlinedVector all_input_names; all_input_names.reserve(CommonOptimizerInputs.size() + optimizer_algo_ptr_->optimizer_states_inputs.size()); all_input_names.insert(all_input_names.end(), CommonOptimizerInputs.begin(), diff --git a/orttraining/orttraining/training_api/optimizer.h b/orttraining/orttraining/training_api/optimizer.h index 36ce3297fe..d9bc4870bb 100644 --- a/orttraining/orttraining/training_api/optimizer.h +++ b/orttraining/orttraining/training_api/optimizer.h @@ -64,8 +64,11 @@ struct SGDOptimizerV2Algorithm : public OptimizerAlgorithmBase { }; struct OptimizerAlorithmFactory { - static std::unique_ptr CreateInstance(const std::string& optim_path_or_bytes, + static std::unique_ptr CreateInstance(const PathString& optim_path, int32_t& group_count); + static std::unique_ptr CreateInstance(const uint8_t* optim_model_data, + size_t optim_model_data_len, int32_t& group_count); + static std::unique_ptr CreateInstance(std::shared_ptr model, int32_t& group_count); }; struct CheckpointState; @@ -96,7 +99,7 @@ struct Optimizer { // Initialize an optimizer module from an ORT inference session with loaded // training ONNX model For each parameter, initialize the OptimizerState based // on the graph input's ValueInfoProto if the parameter doesn't have it already. - Optimizer(const std::string& optim_path_or_bytes, + Optimizer(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, @@ -121,7 +124,7 @@ struct Optimizer { } private: - void Initialize(const std::string& optim_path_or_bytes, + void Initialize(const ModelIdentifiers& model_identifiers, const std::vector>& providers, gsl::span op_domains); diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h index 2b383f3b97..2a8c1e3036 100644 --- a/orttraining/orttraining/training_api/ort_training_apis.h +++ b/orttraining/orttraining/training_api/ort_training_apis.h @@ -8,7 +8,14 @@ ORT_API(const OrtTrainingApi*, GetTrainingApi, uint32_t version); ORT_API_STATUS_IMPL(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, - _Outptr_ OrtTrainingSession** out); + _Outptr_result_maybenull_ OrtTrainingSession** out); + +ORT_API_STATUS_IMPL(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out); ORT_API_STATUS_IMPL(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); @@ -87,4 +94,14 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); +ORT_API_STATUS_IMPL(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); + +ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter); + +ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); + } // namespace OrtTrainingApis diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc index 6915193a8f..45f0f0ddcf 100644 --- a/orttraining/orttraining/training_api/training_session.cc +++ b/orttraining/orttraining/training_api/training_session.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "orttraining/training_api/training_session.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime::training::api { @@ -12,13 +13,12 @@ TrainingSession::TrainingSession(const Environment& session_env, const ModelIdentifiers& model_identifiers, gsl::span custom_op_domains) : state_{state}, - module_{std::make_unique(model_identifiers.train_model, state_, - session_options, session_env, providers, - model_identifiers.eval_model, custom_op_domains)}, - optimizer_{model_identifiers.optim_model.has_value() + module_{std::make_unique(model_identifiers, state_, + session_options, session_env, providers, custom_op_domains)}, + optimizer_{model_identifiers.IsOptimizerModelAvailable() ? std::make_unique( - model_identifiers.optim_model.value(), state_, - session_options, session_env, providers, custom_op_domains) + model_identifiers, state_, + session_options, session_env, providers) : std::unique_ptr()} {} Status TrainingSession::RegisterScheduler( diff --git a/orttraining/orttraining/training_api/training_session.h b/orttraining/orttraining/training_api/training_session.h index 1a16acd511..13b0ae7909 100644 --- a/orttraining/orttraining/training_api/training_session.h +++ b/orttraining/orttraining/training_api/training_session.h @@ -3,25 +3,17 @@ #pragma once #include "core/common/common.h" -#include "module.h" -#include "optimizer.h" -#include "lr_scheduler.h" -#include "checkpoint.h" +#include "orttraining/training_api/module.h" +#include "orttraining/training_api/optimizer.h" +#include "orttraining/training_api/lr_scheduler.h" +#include "orttraining/training_api/checkpoint.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime { namespace training { namespace api { using namespace common; -struct ModelIdentifiers { - const std::string train_model; - const std::optional eval_model, optim_model; - ModelIdentifiers(const std::string& train_model_uri, - const std::optional& eval_model_uri, - const std::optional& optim_model_uri) - : train_model(train_model_uri), eval_model(eval_model_uri), optim_model(optim_model_uri) {} -}; - // Wrapper on top of module and optimizer classes and is the only class exposed via capis class TrainingSession { public: diff --git a/orttraining/orttraining/training_api/utils.h b/orttraining/orttraining/training_api/utils.h index e856554c97..f16f0f947f 100644 --- a/orttraining/orttraining/training_api/utils.h +++ b/orttraining/orttraining/training_api/utils.h @@ -10,6 +10,40 @@ namespace onnxruntime { namespace training { namespace api { + +struct ModelIdentifiers { + // ModelIdentifiers struct enables an easy way to store and identify the models used for training, evaluation + // and model updates(optimizer model). + // The model can be specified by a path to the model file or by a span of bytes containing the model data. + // Training model is required, evaluation and optimizer models are optional. + std::variant> train_model; + std::variant, gsl::span> eval_model; + std::variant, gsl::span> optim_model; + + ModelIdentifiers(std::variant> training_model, + std::variant, gsl::span> evaluation_model, + std::variant, gsl::span> optimzer_model) + : train_model(training_model), eval_model(evaluation_model), optim_model(optimzer_model) {} + + bool IsModelAvailable(const std::variant, gsl::span>& model) const { + if ((std::holds_alternative>(model) && + std::get>(model).has_value()) || + (std::holds_alternative>(model) && + std::get>(model).size() > 0)) { + return true; + } + return false; + } + + bool IsEvalModelAvailable() const { + return IsModelAvailable(eval_model); + } + + bool IsOptimizerModelAvailable() const { + return IsModelAvailable(optim_model); + } +}; + namespace utils { // Get names of graph inputs and outputs diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 6aac9ad7ec..1136bff519 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -85,6 +85,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvTransposeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvTransposeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvTransposeGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DropoutGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskDropoutGrad); @@ -202,6 +205,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inpl class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ResizeGrad); // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training @@ -346,6 +352,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -450,6 +459,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training #ifdef ENABLE_TRAINING diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc index f69da000be..f6c58445c0 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc @@ -3,13 +3,6 @@ #include "orttraining/training_ops/cuda/nn/conv_grad.h" -#include "core/providers/common.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "core/platform/ort_mutex.h" - -// The AlgoPerfCache and AlgoSearch here for Conv/ConvGrad is referenced on PyTorch's implementation -// from aten/src/ATen/native/cudnn/Conv_v7.cpp. - namespace onnxruntime { namespace cuda { @@ -22,229 +15,6 @@ REGISTER_GRADIENT_KERNEL_TYPED(float) REGISTER_GRADIENT_KERNEL_TYPED(double) REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16) -using T_BwdDataPerf = cudnnConvolutionBwdDataAlgoPerf_t; -using T_BwdDataAlgo = cudnnConvolutionBwdDataAlgo_t; -using T_BwdFilterPerf = cudnnConvolutionBwdFilterAlgoPerf_t; -using T_BwdFilterAlgo = cudnnConvolutionBwdFilterAlgo_t; - -cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdDataAlgo algo, size_t* workspace_size) { - return cudnnGetConvolutionBackwardDataWorkspaceSize(args.handle, args.w_desc, args.y_tensor, args.conv_desc, - args.x_tensor, algo, workspace_size); -} - -cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdFilterAlgo algo, size_t* workspace_size) { - return cudnnGetConvolutionBackwardFilterWorkspaceSize(args.handle, args.x_tensor, args.y_tensor, args.conv_desc, - args.w_desc, algo, workspace_size); -} - -template -size_t GetMaxWorkspaceSize(const ConvArgs& args, const T_Algo* algo, int n_algo) { - // Calling cudaMemGetInfo is not ideal, but our cuda allocator doesn't have a way to get this info. - size_t free, total; - CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); - // Assuming 10% of fragmentation. - free = static_cast(static_cast(free) * 0.9); - size_t max_workspace_size = 0; - for (int i = 0; i < n_algo; i++) { - cudnnStatus_t status; - size_t workspace_size; - status = GetWorkspaceSize(args, algo[i], &workspace_size); - if (CUDNN_STATUS_SUCCESS != status || workspace_size == 0 || workspace_size < max_workspace_size || - workspace_size > free) - continue; - max_workspace_size = workspace_size; - } - - return max_workspace_size; -} - -template -std::vector GetValidAlgorithms(const T_Perf* perf_results, int n_algo) { - std::vector result; - result.reserve(n_algo); - for (int i = 0; i < n_algo; i++) { - T_Perf perf = perf_results[i]; - if (perf.status == CUDNN_STATUS_SUCCESS) { - result.emplace_back(perf); - } - } - ORT_ENFORCE(result.size() > 0, "No valid convolution algorithms available in CuDNN"); - // TODO: This is a cuDNN bug that gave wrong results in certain strided convolution gradient setups - // when cuDNN version < 7.5. Need to add handling for such special case. - return result; -} - -struct ConvParamsHash { - // ConvParams must be a POD because we read out its memory constant as char* when hashing. - static_assert(std::is_pod::value, "ConvParams is not POD"); - size_t operator()(const ConvParams& conv_params) const { - auto ptr = reinterpret_cast(&conv_params); - uint32_t value = 0x811C9DC5; - for (int i = 0; i < static_cast(sizeof(ConvParams)); ++i) { - value ^= ptr[i]; - value *= 0x01000193; - } - return static_cast(value); - } -}; - -struct ConvParamsEqual { - // ConvParams must be a POD because we read out its memory constant as char* when hashing. - static_assert(std::is_pod::value, "ConvParams is not POD"); - bool operator()(const ConvParams& a, const ConvParams& b) const { - auto ptr1 = reinterpret_cast(&a); - auto ptr2 = reinterpret_cast(&b); - return memcmp(ptr1, ptr2, sizeof(ConvParams)) == 0; - } -}; - -template -struct AlgoPerfCache { - mutable OrtMutex mutex; - std::unordered_map map; - - bool Find(const ConvParams& params, T_Perf* result) { - std::lock_guard guard(mutex); - auto it = map.find(params); - if (it == map.end()) { - return false; - } - *result = it->second; - return true; - } - - void Insert(const ConvParams& params, const T_Perf& algo_perf) { - std::lock_guard guard(mutex); - map[params] = algo_perf; - } -}; - -// TODO: Currently we use global AlgoPerfCache for ConvGrad only. Conv's perf cache is till per node. -// Need to apply such global cache for Conv, and move some shared code from here to conv.h/cc. -AlgoPerfCache bwd_data_algos; -AlgoPerfCache bwd_filter_algos; - -template -struct AlgoSearch {}; - -template <> -struct AlgoSearch { - static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; - static AlgoPerfCache& Cache() { return bwd_data_algos; } - static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, - std::vector& perf_results) { - static const T_BwdDataAlgo algos[] = { - CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, - CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, - CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; - static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms."); - int perf_count; - std::unique_ptr candidates = std::make_unique(num_algos); - if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { - CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(args.handle, args.w_desc, args.y_tensor, - args.conv_desc, args.x_tensor, num_algos, - &perf_count, candidates.get())); - } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { - size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) - : AlgoSearchWorkspaceSize; - // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. - // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. - IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( - args.handle, args.w_desc, args.w_data, args.y_tensor, args.dy_data, args.conv_desc, args.x_tensor, - args.dx_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); - } else { - ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); - } - perf_results = GetValidAlgorithms(candidates.get(), perf_count); - return Status::OK(); - } -}; - -template <> -struct AlgoSearch { - static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; - static AlgoPerfCache& Cache() { return bwd_filter_algos; } - static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, - std::vector& perf_results) { - static const T_BwdFilterAlgo algos[] = { - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, - }; - - // NOTE: - 1 because ALGO_WINOGRAD is not implemented. - static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); - std::unique_ptr candidates = std::make_unique(num_algos); - int perf_count; - if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { - CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(args.handle, args.x_tensor, args.y_tensor, - args.conv_desc, args.w_desc, num_algos, - &perf_count, candidates.get())); - } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { - size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) - : AlgoSearchWorkspaceSize; - // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. - // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. - IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardFilterAlgorithmEx( - args.handle, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc, args.w_desc, - args.dw_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); - } else { - ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); - } - perf_results = GetValidAlgorithms(candidates.get(), perf_count); - return Status::OK(); - } -}; - -template -class AlgoIterator { - public: - AlgoIterator(const ConvArgs& args) : args_(args) {} - - static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { - perf_results.resize(1); - perf_results[0].algo = AlgoSearch::DEFAULT_ALGO; - if (args.params.data_type == CUDNN_DATA_HALF) { - perf_results[0].mathType = CUDNN_TENSOR_OP_MATH; - } else { - perf_results[0].mathType = CUDNN_DEFAULT_MATH; - } - CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].algo, &(perf_results[0].memory))); - return Status::OK(); - } - - Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, std::function f) { - auto& cache = AlgoSearch::Cache(); - - if (T_Perf algo_perf; cache.Find(args_.params, &algo_perf) && f(algo_perf) == Status::OK()) { - return Status::OK(); - } - - std::vector perf_results; - ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault - ? OnlyDefaultAlgorithm(args_, perf_results) - : AlgoSearch::FindAlgorithms(args_, provider, allocator, perf_results)); - for (auto& algo_perf : perf_results) { - if (f(algo_perf) == Status::OK()) { - cache.Insert(args_.params, algo_perf); - return Status::OK(); - } - } - ORT_ENFORCE(false, "Unable to find a valid cuDNN algorithm to run convolution."); - return Status::OK(); - } - - private: - const ConvArgs& args_; -}; - template Status ConvGrad::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w, Tensor* dB, Tensor* dX, Tensor* dW, cudnnHandle_t cudnn_handle) const { diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h index 5d0c123fd9..9bbcd5b30d 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h @@ -3,47 +3,11 @@ #pragma once -#include "core/providers/cuda/cudnn_common.h" -#include "core/providers/cpu/nn/conv_attributes.h" -#include "core/providers/cuda/nn/conv.h" +#include "orttraining/training_ops/cuda/nn/conv_shared.h" namespace onnxruntime { namespace cuda { -// cuDNN only takes 4D or 5D x tensor. -static constexpr int MAX_DIM = 3; - -struct ConvParams { - int8_t device_id; - cudnnDataType_t data_type; - int input_size[2 + MAX_DIM]; - uint8_t input_dim; - int weight_size[2 + MAX_DIM]; - int padding[MAX_DIM * 2]; - int stride[MAX_DIM]; - int dilation[MAX_DIM]; - int64_t groups; - int algo_mode; -}; - -struct ConvArgs { - // Update needed if x or w's dims changed. - TensorShapeVector last_x_dims; - TensorShapeVector last_w_dims; - - cudnnHandle_t handle; - ConvParams params; - CudnnTensor x_tensor, y_tensor, b_tensor; - CudnnFilterDescriptor w_desc; - CudnnConvolutionDescriptor conv_desc; - const void* x_data; - const void* w_data; - const void* dy_data; - void* dx_data; - void* dw_data; - void* db_data; -}; - template class ConvGrad final : public CudaKernel { public: diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc new file mode 100644 index 0000000000..5dc16c68f6 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/nn/conv_shared.h" + +#include "core/platform/ort_mutex.h" +#include "core/providers/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime::cuda { + +namespace { + +cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdDataAlgo algo, size_t* workspace_size) { + return cudnnGetConvolutionBackwardDataWorkspaceSize(args.handle, args.w_desc, args.y_tensor, args.conv_desc, + args.x_tensor, algo, workspace_size); +} + +cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdFilterAlgo algo, size_t* workspace_size) { + return cudnnGetConvolutionBackwardFilterWorkspaceSize(args.handle, args.x_tensor, args.y_tensor, args.conv_desc, + args.w_desc, algo, workspace_size); +} + +cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_FwdAlgo algo, size_t* workspace_size) { + return cudnnGetConvolutionForwardWorkspaceSize(args.handle, args.x_tensor, args.w_desc, args.conv_desc, + args.y_tensor, algo, workspace_size); +} + +template +size_t GetMaxWorkspaceSize(const ConvArgs& args, const T_Algo* algo, int n_algo) { + // Calling cudaMemGetInfo is not ideal, but our cuda allocator doesn't have a way to get this info. + size_t free, total; + CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); + // Assuming 10% of fragmentation. + free = static_cast(static_cast(free) * 0.9); + size_t max_workspace_size = 0; + for (int i = 0; i < n_algo; i++) { + cudnnStatus_t status; + size_t workspace_size; + status = GetWorkspaceSize(args, algo[i], &workspace_size); + if (CUDNN_STATUS_SUCCESS != status || workspace_size == 0 || workspace_size < max_workspace_size || + workspace_size > free) + continue; + max_workspace_size = workspace_size; + } + + return max_workspace_size; +} + +template +std::vector GetValidAlgorithms(const T_Perf* perf_results, int n_algo) { + std::vector result; + result.reserve(n_algo); + for (int i = 0; i < n_algo; i++) { + T_Perf perf = perf_results[i]; + if (perf.status == CUDNN_STATUS_SUCCESS) { + result.emplace_back(perf); + } + } + ORT_ENFORCE(result.size() > 0, "No valid convolution algorithms available in CuDNN"); + // TODO: This is a cuDNN bug that gave wrong results in certain strided convolution gradient setups + // when cuDNN version < 7.5. Need to add handling for such special case. + return result; +} + +template +struct AlgoPerfCache { + mutable OrtMutex mutex; + std::unordered_map map; + + bool Find(const ConvParams& params, T_Perf* result) { + std::lock_guard guard(mutex); + auto it = map.find(params); + if (it == map.end()) { + return false; + } + *result = it->second; + return true; + } + + void Insert(const ConvParams& params, const T_Perf& algo_perf) { + std::lock_guard guard(mutex); + map[params] = algo_perf; + } +}; + +// TODO: Currently we use global AlgoPerfCache for ConvGrad and ConvTransposeGrad only. +// Conv's perf cache is still per node. +// Need to apply such global cache for Conv, and move some shared code from here to conv.h/cc. +AlgoPerfCache bwd_data_algos; +AlgoPerfCache bwd_filter_algos; +AlgoPerfCache fwd_algos; + +template +struct AlgoSearch {}; + +template <> +struct AlgoSearch { + static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + static AlgoPerfCache& Cache() { return bwd_data_algos; } + static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::vector& perf_results) { + static const T_BwdDataAlgo algos[] = { + CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; + static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; + ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms."); + int perf_count; + std::unique_ptr candidates = std::make_unique(num_algos); + if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(args.handle, args.w_desc, args.y_tensor, + args.conv_desc, args.x_tensor, num_algos, + &perf_count, candidates.get())); + } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { + size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( + args.handle, args.w_desc, args.w_data, args.y_tensor, args.dy_data, args.conv_desc, args.x_tensor, + args.dx_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); + } else { + ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); + } + perf_results = GetValidAlgorithms(candidates.get(), perf_count); + return Status::OK(); + } +}; + +template <> +struct AlgoSearch { + static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + static AlgoPerfCache& Cache() { return bwd_filter_algos; } + static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::vector& perf_results) { + static const T_BwdFilterAlgo algos[] = { + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, + }; + + // NOTE: - 1 because ALGO_WINOGRAD is not implemented. + static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1; + ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates = std::make_unique(num_algos); + int perf_count; + if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(args.handle, args.x_tensor, args.y_tensor, + args.conv_desc, args.w_desc, num_algos, + &perf_count, candidates.get())); + } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { + size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardFilterAlgorithmEx( + args.handle, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc, args.w_desc, + args.dw_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); + } else { + ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); + } + perf_results = GetValidAlgorithms(candidates.get(), perf_count); + return Status::OK(); + } +}; + +template <> +struct AlgoSearch { + static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + static AlgoPerfCache& Cache() { return fwd_algos; } + static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::vector& perf_results) { + static const T_FwdAlgo algos[] = { + CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, + }; + + static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; + ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates = std::make_unique(num_algos); + int perf_count; + if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(args.handle, args.x_tensor, args.w_desc, + args.conv_desc, args.y_tensor, num_algos, + &perf_count, candidates.get())); + } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { + size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr workspace = max_workspace_size == 0 + ? nullptr + : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx( + args.handle, args.x_tensor, args.x_data, args.w_desc, args.w_data, args.conv_desc, args.y_tensor, + args.y_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); + } else { + ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); + } + perf_results = GetValidAlgorithms(candidates.get(), perf_count); + return Status::OK(); + } +}; + +} // namespace + +size_t ConvParamsHash::operator()(const ConvParams& conv_params) const { + auto ptr = reinterpret_cast(&conv_params); + uint32_t value = 0x811C9DC5; + for (int i = 0; i < static_cast(sizeof(ConvParams)); ++i) { + value ^= ptr[i]; + value *= 0x01000193; + } + return static_cast(value); +} + +bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const { + auto ptr1 = reinterpret_cast(&a); + auto ptr2 = reinterpret_cast(&b); + return memcmp(ptr1, ptr2, sizeof(ConvParams)) == 0; +} + +template +Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { + perf_results.resize(1); + perf_results[0].algo = AlgoSearch::DEFAULT_ALGO; + if (args.params.data_type == CUDNN_DATA_HALF) { + perf_results[0].mathType = CUDNN_TENSOR_OP_MATH; + } else { + perf_results[0].mathType = CUDNN_DEFAULT_MATH; + } + CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].algo, &(perf_results[0].memory))); + return Status::OK(); +} + +template +Status AlgoIterator::TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::function f) { + auto& cache = AlgoSearch::Cache(); + + if (T_Perf algo_perf; cache.Find(args_.params, &algo_perf) && f(algo_perf) == Status::OK()) { + return Status::OK(); + } + + std::vector perf_results; + ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault + ? OnlyDefaultAlgorithm(args_, perf_results) + : AlgoSearch::FindAlgorithms(args_, provider, allocator, perf_results)); + for (auto& algo_perf : perf_results) { + if (f(algo_perf) == Status::OK()) { + cache.Insert(args_.params, algo_perf); + return Status::OK(); + } + } + ORT_ENFORCE(false, "Unable to find a valid cuDNN algorithm to run convolution."); + return Status::OK(); +} + +template class AlgoIterator; +template class AlgoIterator; +template class AlgoIterator; + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h new file mode 100644 index 0000000000..a2d4bf3bdc --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/nn/conv.h" + +// The AlgoPerfCache and AlgoSearch here for Conv/ConvGrad/ConvTransposeGrad is adapted from PyTorch's implementation +// in aten/src/ATen/native/cudnn/Conv_v7.cpp. + +namespace onnxruntime::cuda { + +using T_BwdDataPerf = cudnnConvolutionBwdDataAlgoPerf_t; +using T_BwdDataAlgo = cudnnConvolutionBwdDataAlgo_t; +using T_BwdFilterPerf = cudnnConvolutionBwdFilterAlgoPerf_t; +using T_BwdFilterAlgo = cudnnConvolutionBwdFilterAlgo_t; +using T_FwdAlgo = cudnnConvolutionFwdAlgo_t; +using T_FwdPerf = cudnnConvolutionFwdAlgoPerf_t; + +// cuDNN only takes 4D or 5D x tensor. +static constexpr int MAX_DIM = 3; + +struct ConvParams { + int8_t device_id; + cudnnDataType_t data_type; + int input_size[2 + MAX_DIM]; + uint8_t input_dim; + int weight_size[2 + MAX_DIM]; + int padding[MAX_DIM * 2]; + int stride[MAX_DIM]; + int dilation[MAX_DIM]; + int64_t groups; + int algo_mode; +}; + +struct ConvArgs { + // Update needed if x or w's dims changed. + TensorShapeVector last_x_dims; // Input to the convolution + TensorShapeVector last_w_dims; // Weights of the convolution + + cudnnHandle_t handle; + ConvParams params; + CudnnTensor x_tensor, y_tensor, b_tensor; + CudnnFilterDescriptor w_desc; + CudnnConvolutionDescriptor conv_desc; + const void* x_data; + const void* w_data; + const void* dy_data; + void* y_data; + void* dx_data; + void* dw_data; + void* db_data; +}; + +struct ConvParamsHash { + // ConvParams must be a POD because we read out its memory constant as char* when hashing. + static_assert(std::is_pod::value, "ConvParams is not POD"); + + size_t operator()(const ConvParams& conv_params) const; +}; + +struct ConvParamsEqual { + // ConvParams must be a POD because we read out its memory constant as char* when hashing. + static_assert(std::is_pod::value, "ConvParams is not POD"); + + bool operator()(const ConvParams& a, const ConvParams& b) const; +}; + +template +class AlgoIterator { + public: + AlgoIterator(const ConvArgs& args) : args_(args) {} + + Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::function f); + + static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results); + + private: + const ConvArgs& args_; +}; + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc new file mode 100644 index 0000000000..5f7206fc12 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc @@ -0,0 +1,308 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/nn/conv_transpose_grad.h" + +namespace onnxruntime::cuda { + +#define REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(ConvTransposeGrad, kMSDomain, 1, T, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ConvTransposeGrad); + +REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(float) +REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(double) +REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(MLFloat16) + +template +Status ConvTransposeGrad::ComputeInternal(OpKernelContext* context) const { + const Tensor* dY = context->Input(0); + const Tensor* X = context->Input(1); + const Tensor* W = context->Input(2); + Tensor* dX = context->Output(0, X->Shape()); + Tensor* dW = context->Output(1, W->Shape()); + Tensor* dB = context->Output(2, {W->Shape()[1] * conv_attrs_.group}); + + if (dX) { + ORT_RETURN_IF_ERROR(PrepareConvForwardArgs(*dY, *W, *dX, GetCudnnHandle(context), args_dx_)); + ORT_RETURN_IF_ERROR(ComputeInputGradient(context->GetComputeStream(), args_dx_)); + } + + if (dW || dB) { + ORT_RETURN_IF_ERROR(PrepareConvBackwardFilterArgs(*dY, *W, *X, dW, dB, GetCudnnHandle(context), args_dw_)); + if (dW) ORT_RETURN_IF_ERROR(ComputeWeightGradient(context->GetComputeStream(), args_dw_)); + if (dB) ORT_RETURN_IF_ERROR(ComputeBiasGradient(args_dw_)); + } + + return Status::OK(); +} + +template +Status ConvTransposeGrad::ComputeInputGradient(onnxruntime::Stream* stream, const ConvArgs& args) const { + return AlgoIterator(args).TryAll( + static_cast(Info().GetExecutionProvider()), + Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + [&](const T_FwdPerf& algo_perf) -> Status { + const auto one = Consts::One; + const auto zero = Consts::Zero; + IAllocatorUniquePtr workspace = GetScratchBuffer(algo_perf.memory, stream); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args.conv_desc, algo_perf.mathType)); + CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward( + args.handle, &one, args.x_tensor, args.x_data, args.w_desc, args.w_data, args.conv_desc, + algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.y_tensor, args.y_data)); + return Status::OK(); + }); + return Status::OK(); +} + +template +Status ConvTransposeGrad::ComputeWeightGradient(onnxruntime::Stream* stream, const ConvArgs& args) const { + return AlgoIterator(args).TryAll( + static_cast(Info().GetExecutionProvider()), + Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + [&](const T_BwdFilterPerf& algo_perf) -> Status { + const auto one = Consts::One; + const auto zero = Consts::Zero; + IAllocatorUniquePtr workspace = GetScratchBuffer(algo_perf.memory, stream); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args.conv_desc, algo_perf.mathType)); + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardFilter( + args.handle, &one, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc, + algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.w_desc, args.dw_data)); + return Status::OK(); + }); + return Status::OK(); +} + +template +Status ConvTransposeGrad::ComputeBiasGradient(const ConvArgs& args) const { + const auto one = Consts::One; + const auto zero = Consts::Zero; + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardBias(args.handle, &one, args.x_tensor, args.x_data, &zero, + args.b_tensor, args.db_data)); + return Status::OK(); +} + +template +Status ConvTransposeGrad::PrepareConvForwardArgs(const Tensor& X, const Tensor& W, + Tensor& Y, cudnnHandle_t cudnn_handle, + ConvArgs& args) const { + const TensorShape& x_shape = X.Shape(); + auto x_dims = x_shape.AsShapeVector(); + args.x_data = reinterpret_cast(X.template Data()); + + const TensorShape& w_shape = W.Shape(); + auto w_dims = w_shape.AsShapeVector(); + args.w_data = reinterpret_cast(W.template Data()); + + const TensorShape& y_shape = Y.Shape(); + auto y_dims = y_shape.AsShapeVector(); + args.y_data = reinterpret_cast(Y.template MutableData()); + + args.dy_data = nullptr; + args.db_data = nullptr; + args.dx_data = nullptr; + args.dw_data = nullptr; + + bool x_dims_changed = (args.last_x_dims != x_dims); + bool w_dims_changed = (args.last_w_dims != w_dims); + if (x_dims_changed || w_dims_changed) { + if (x_dims_changed) args.last_x_dims = x_dims; + if (w_dims_changed) args.last_w_dims = w_dims; + + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&X, &W)); + + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape)); + auto rank = kernel_shape.size(); + + ConvPadVector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(rank * 2, 0); + } + + TensorShapeVector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(rank, 1); + } + + TensorShapeVector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(rank, 1); + } + + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); + + if (rank < 2) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims.insert(x_dims.begin() + 2, 1); + y_dims.insert(y_dims.begin() + 2, 1); + w_dims.insert(w_dims.begin() + 2, 1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims.push_back(1); + y_dims.push_back(1); + w_dims.push_back(1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + memset(&args.params, 0, sizeof(ConvParams)); + args.params.device_id = static_cast(cuda_ep->GetDeviceId()); + args.params.data_type = CudnnTensor::GetDataType(); + args.params.input_dim = static_cast(x_dims.size()); + for (size_t i = 0; i < x_dims.size(); i++) { + args.params.input_size[i] = static_cast(x_dims[i]); + args.params.weight_size[i] = static_cast(w_dims[i]); + } + for (size_t i = 0; i < rank; i++) { + args.params.padding[i] = static_cast(pads[i]); + args.params.padding[i + rank] = static_cast(pads[i + rank]); + args.params.stride[i] = static_cast(strides[i]); + args.params.dilation[i] = static_cast(dilations[i]); + } + args.params.groups = conv_attrs_.group; + int algo_mode = cuda_ep->GetCudnnConvAlgo(); + ORT_ENFORCE(algo_mode > -1 && algo_mode < 3, + "Algo mode should be EXHAUSTIVE (0), HEURISTIC (1) or DEFAULT (2), but got ", algo_mode); + args.params.algo_mode = algo_mode; + + args.handle = cudnn_handle; + ORT_RETURN_IF_ERROR(args.w_desc.Set(w_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.x_tensor.Set(x_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, + args.params.data_type)); + } + + return Status::OK(); +} + +template +Status ConvTransposeGrad::PrepareConvBackwardFilterArgs(const Tensor& X, const Tensor& W, const Tensor& dY, + Tensor* dW, Tensor* dB, cudnnHandle_t cudnn_handle, + ConvArgs& args) const { + const TensorShape& x_shape = X.Shape(); + auto x_dims = x_shape.AsShapeVector(); + args.x_data = reinterpret_cast(X.template Data()); + + const TensorShape& y_shape = dY.Shape(); + auto y_dims = y_shape.AsShapeVector(); + args.dy_data = reinterpret_cast(dY.template Data()); + + const TensorShape& w_shape = W.Shape(); + auto w_dims = w_shape.AsShapeVector(); + + args.y_data = nullptr; + args.dw_data = dW ? reinterpret_cast(dW->template MutableData()) : nullptr; + args.db_data = dB ? reinterpret_cast(dB->template MutableData()) : nullptr; + args.dx_data = nullptr; + args.w_data = nullptr; + + bool x_dims_changed = (args.last_x_dims != x_dims); + bool w_dims_changed = (args.last_w_dims != w_dims); + if (x_dims_changed || w_dims_changed) { + if (x_dims_changed) args.last_x_dims = x_dims; + if (w_dims_changed) args.last_w_dims = w_dims; + + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&X, &W)); + + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape)); + auto rank = kernel_shape.size(); + + ConvPadVector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(rank * 2, 0); + } + + TensorShapeVector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(rank, 1); + } + + TensorShapeVector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(rank, 1); + } + + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); + + if (rank < 2) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims.insert(x_dims.begin() + 2, 1); + y_dims.insert(y_dims.begin() + 2, 1); + w_dims.insert(w_dims.begin() + 2, 1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims.push_back(1); + y_dims.push_back(1); + w_dims.push_back(1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + memset(&args.params, 0, sizeof(ConvParams)); + args.params.device_id = static_cast(cuda_ep->GetDeviceId()); + args.params.data_type = CudnnTensor::GetDataType(); + args.params.input_dim = static_cast(x_dims.size()); + for (size_t i = 0; i < x_dims.size(); i++) { + args.params.input_size[i] = static_cast(x_dims[i]); + args.params.weight_size[i] = static_cast(w_dims[i]); + } + for (size_t i = 0; i < rank; i++) { + args.params.padding[i] = static_cast(pads[i]); + args.params.padding[i + rank] = static_cast(pads[i + rank]); + args.params.stride[i] = static_cast(strides[i]); + args.params.dilation[i] = static_cast(dilations[i]); + } + args.params.groups = conv_attrs_.group; + int algo_mode = cuda_ep->GetCudnnConvAlgo(); + ORT_ENFORCE(algo_mode > -1 && algo_mode < 3, + "Algo mode should be EXHAUSTIVE (0), HEURISTIC (1) or DEFAULT (2), but got ", algo_mode); + args.params.algo_mode = algo_mode; + + args.handle = cudnn_handle; + ORT_RETURN_IF_ERROR(args.w_desc.Set(w_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.x_tensor.Set(x_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, + args.params.data_type)); + + if (dB) { + const auto& b_shape = dB->Shape(); + ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); + TensorShapeVector b_dims(2 + kernel_shape.size()); + b_dims[0] = 1; // N + b_dims[1] = b_shape[0]; // C + for (size_t i = 0; i < kernel_shape.size(); i++) + b_dims[2 + i] = 1; + + ORT_RETURN_IF_ERROR(args.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); + } + } + + return Status::OK(); +} + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.h b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.h new file mode 100644 index 0000000000..72426323fe --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" + +#include "core/providers/cpu/nn/conv_attributes.h" +#include "orttraining/training_ops/cuda/nn/conv_shared.h" + +namespace onnxruntime::cuda { + +template +class ConvTransposeGrad final : public CudaKernel { + public: + using CudaT = typename ToCudaType::MappedType; + + ConvTransposeGrad(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) { + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + Status ComputeWeightGradient(onnxruntime::Stream* stream, const ConvArgs& args) const; + Status ComputeInputGradient(onnxruntime::Stream* stream, const ConvArgs& args) const; + Status ComputeBiasGradient(const ConvArgs& args) const; + + Status PrepareConvForwardArgs(const Tensor& X, const Tensor& W, + Tensor& Y, cudnnHandle_t cudnn_handle, + ConvArgs& args) const; + + Status PrepareConvBackwardFilterArgs(const Tensor& X, const Tensor& W, const Tensor& dY, + Tensor* dW, Tensor* dB, cudnnHandle_t cudnn_handle, + ConvArgs& args) const; + + ConvAttributes conv_attrs_; + mutable ConvArgs args_dx_; + mutable ConvArgs args_dw_; +}; + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc new file mode 100644 index 0000000000..a5e8f7cd35 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "orttraining/training_ops/cuda/tensor/resize_grad.h" +#include "orttraining/training_ops/cuda/tensor/resize_grad_impl.h" + +namespace onnxruntime::cuda { + +#define REGISTER_RESIZEGRAD_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ResizeGrad, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) /* Keep roi on CPU */ \ + .InputMemoryType(OrtMemTypeCPUInput, 3) /* Keep scales on CPU */ \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ResizeGrad); + +REGISTER_RESIZEGRAD_KERNEL_TYPED(MLFloat16) +REGISTER_RESIZEGRAD_KERNEL_TYPED(float) +REGISTER_RESIZEGRAD_KERNEL_TYPED(double) + +template +Status ResizeGrad::ComputeInternal(OpKernelContext* context) const { + typedef typename ToCudaType::MappedType CudaT; + + const Tensor* dY = context->Input(0); + const Tensor* X = context->Input(1); + const Tensor* scales = context->Input(3); + + ORT_ENFORCE(X->Shape().NumDimensions() == 4, "Expected input tensor to have 4 dimensions. Actual: ", + X->Shape().NumDimensions()); + + const auto get_scales_from_input = [](const Tensor* scales) { + if (nullptr == scales) { + return std::make_pair(std::optional{}, std::optional{}); + } + + ORT_ENFORCE(scales->Shape().Size() == 4, "There must be a scale for each dimension."); + + const auto* scales_data = scales->Data(); + return std::make_pair(std::optional{scales_data[2]}, std::optional{scales_data[3]}); + }; + + std::pair, std::optional> scale_factors = get_scales_from_input(scales); + + Tensor* dX = context->Output(0, X->Shape()); + + const int64_t batch_size = X->Shape()[0]; + const int64_t num_channels = X->Shape()[1]; + const int64_t output_height = dY->Shape()[2]; + const int64_t output_width = dY->Shape()[3]; + const int64_t input_height = X->Shape()[2]; + const int64_t input_width = X->Shape()[3]; + + if (dX->Shape() == dY->Shape()) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dX->MutableDataRaw(), dY->DataRaw(), dY->SizeInBytes(), cudaMemcpyDeviceToDevice)); + return Status::OK(); + } + + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(dX->MutableDataRaw(), 0, dX->SizeInBytes(), Stream(context))); + + const bool align_corners = coordinate_transform_mode_ == ResizeCoordinateTransformationMode::ALIGN_CORNERS; + const CudaT* dy_data = reinterpret_cast(dY->Data()); + CudaT* dx_data = reinterpret_cast(dX->MutableData()); + + ResizeGradImpl(Stream(context), input_height, input_width, output_height, + output_width, batch_size, num_channels, align_corners, + scale_factors.first, scale_factors.second, + dy_data, dx_data); + + return Status::OK(); +} + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h new file mode 100644 index 0000000000..53f8d5f0d7 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cpu/tensor/upsamplebase.h" + +namespace onnxruntime::cuda { + +template +class ResizeGrad final : public UpsampleBase, public CudaKernel { + public: + ResizeGrad(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { + ORT_ENFORCE(!antialias_, "Antialiasing is not supported in ResizeGrad yet."); + + ORT_ENFORCE(axes_.empty(), "ReizeGrad does not support the `axes` attribute yet."); + + std::string coordinate_transform_mode = + info.GetAttrOrDefault("coordinate_transformation_mode", "half_pixel"); + coordinate_transform_mode_ = StringToCoordinateTransformationMode(coordinate_transform_mode); + ORT_ENFORCE(coordinate_transform_mode_ == ResizeCoordinateTransformationMode::HALF_PIXEL || + coordinate_transform_mode_ == ResizeCoordinateTransformationMode::ALIGN_CORNERS, + "ReizeGrad only supports the `HALF_PIXEL` and `ALIGN_CORNERS` coordinate_transform_mode ", + coordinate_transform_mode, " is not supported yet."); + + ORT_ENFORCE(keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH, + "ReizeGrad only supports the `STRETCH` policy."); + + std::string mode; + ORT_ENFORCE(info.GetAttr("mode", &mode).IsOK()); + ORT_ENFORCE((UpsampleMode::LINEAR == mode_), + "ReizeGrad only supports the `LINEAR` mode. ", mode, " mode is not supported yet."); + } + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu new file mode 100644 index 0000000000..0507cda623 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Contents of this file are derived from the pytorch cuda implementation of +// the upsample_bilinear2d_backward implementation at: +// https://github.com/pytorch/pytorch/blob/ce50132748f652ed6079c3db8008a6817594dbae/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu + +#include "orttraining/training_ops/cuda/tensor/resize_grad_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/atomic/common.cuh" + +namespace onnxruntime::cuda { + +namespace { + +constexpr int NumThreadsPerBlock = GridDim::maxThreadsPerBlock; + +} // namespace + +__device__ __forceinline__ size_t +idx(const size_t nc, + const size_t height, + const size_t width, + const size_t h, + const size_t w) { + return (nc * height + h) * width + w; +} + +template +__device__ __forceinline__ static T AreaPixelComputeSourceIndex( + T scale, + int dst_index, + bool align_corners, + bool cubic) { + if (align_corners) { + return scale * dst_index; + } else { + T src_idx = scale * (dst_index + static_cast(0.5)) - + static_cast(0.5); + return (!cubic && src_idx < static_cast(0)) + ? static_cast(0) + : src_idx; + } +} + +template +__global__ void UpsampleGrad(const int64_t nc, const int64_t input_height, + const int64_t input_width, const int64_t output_height, + const int64_t output_width, const AccT rheight, + const AccT rwidth, const bool align_corners, + const T* dY_data, T* dX_data) { + const size_t dy_numel = nc * output_width * output_height; + const size_t dx_numel = nc * input_width * input_height; + for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; + index < dy_numel; + index += blockDim.x * gridDim.x) { + size_t index_temp = index; + const int w2 = index_temp % output_width; // 0:width2-1 + index_temp /= output_width; + const int h2 = index_temp % output_height; // 0:height2-1 + const size_t nc = index_temp / output_height; + + const AccT h1r = AreaPixelComputeSourceIndex( + rheight, h2, align_corners, /*cubic=*/false); + const int h1 = h1r; + const int h1p = (h1 < input_height - 1) ? 1 : 0; + const AccT h1lambda = h1r - h1; + const AccT h0lambda = static_cast(1) - h1lambda; + + const AccT w1r = AreaPixelComputeSourceIndex( + rwidth, w2, align_corners, /*cubic=*/false); + const int w1 = w1r; + const int w1p = (w1 < input_width - 1) ? 1 : 0; + const AccT w1lambda = w1r - w1; + const AccT w0lambda = static_cast(1) - w1lambda; + + const T d2val = dY_data[index]; + AtomicAdd( + dX_data, + idx(nc, input_height, input_width, h1, w1), + dx_numel, + static_cast(h0lambda * w0lambda) * d2val); + AtomicAdd( + dX_data, + idx(nc, input_height, input_width, h1, w1 + w1p), + dx_numel, + static_cast(h0lambda * w1lambda) * d2val); + AtomicAdd( + dX_data, + idx(nc, input_height, input_width, h1 + h1p, w1), + dx_numel, + static_cast(h1lambda * w0lambda) * d2val); + AtomicAdd( + dX_data, + idx(nc, input_height, input_width, h1 + h1p, w1 + w1p), + dx_numel, + static_cast(h1lambda * w1lambda) * d2val); + } +} + +template +T AreaPixelComputeScale(int64_t input_size, int64_t output_size, bool align_corners, + const std::optional& scale) { + if (align_corners) { + if (output_size <= 1) { + return T{0}; + } + return static_cast(input_size - 1) / static_cast(output_size - 1); + } else { + if (scale.has_value()) { + return static_cast(T{1.0} / *scale); + } else { + return static_cast(input_size) / static_cast(output_size); + } + } +} + +template +void ResizeGradImpl(cudaStream_t stream, int64_t input_height, + int64_t input_width, int64_t output_height, + int64_t output_width, int64_t batch_size, + int64_t channels, bool align_corners, + const std::optional& scale_height, + const std::optional& scale_width, + const T* dY_data, T* dX_data) { + float rheight = AreaPixelComputeScale(input_height, output_height, align_corners, scale_height); + float rwidth = AreaPixelComputeScale(input_width, output_width, align_corners, scale_width); + + const size_t output_numel = batch_size * channels * output_height * output_width; + int blocks_per_grid = (int)(ceil(static_cast(output_numel) / NumThreadsPerBlock)); + UpsampleGrad<<>>( + batch_size * channels, input_height, input_width, output_height, output_width, + rheight, rwidth, align_corners, dY_data, dX_data); +} + +#define SPECIALIZED_RESIZEGRAD_IMPL(T) \ + template void ResizeGradImpl(cudaStream_t stream, int64_t input_height, \ + int64_t input_width, int64_t output_height, \ + int64_t output_width, int64_t batch_size, \ + int64_t channels, bool align_corners, \ + const std::optional& scale_height, \ + const std::optional& scale_width, \ + const T* dY_data, T* dX_data); + +SPECIALIZED_RESIZEGRAD_IMPL(half) +SPECIALIZED_RESIZEGRAD_IMPL(float) +SPECIALIZED_RESIZEGRAD_IMPL(double) + +#undef SPECIALIZED_RESIZEGRAD_IMPL + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h new file mode 100644 index 0000000000..3e917f9071 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +namespace onnxruntime::cuda { + +template +void ResizeGradImpl(cudaStream_t stream, int64_t input_height, + int64_t input_width, int64_t output_height, + int64_t output_width, int64_t batch_size, + int64_t channels, bool align_corners, + const std::optional& scale_height, + const std::optional& scale_width, + const T* dY_data, T* dX_data); + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index 2321aa23dd..e0749c2fb4 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -187,6 +187,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad); #if defined(ORT_USE_NCCL) || defined(USE_MPI) // P2P communication operators. @@ -387,6 +390,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // P2P communication operators. #if defined(ORT_USE_NCCL) || defined(USE_MPI) diff --git a/setup.py b/setup.py index c4bbb67947..458b161e9c 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,7 @@ import datetime import logging import platform +import shlex import subprocess import sys from glob import glob, iglob @@ -183,108 +184,37 @@ def run(self): dest = "onnxruntime/capi/onnxruntime_pybind11_state_manylinux1.so" logger.info("copying %s -> %s", source, dest) copyfile(source, dest) - result = subprocess.run( - ["patchelf", "--print-needed", dest], check=True, stdout=subprocess.PIPE, text=True - ) - dependencies = [ - "librccl.so", - "libamdhip64.so", - "librocblas.so", - "libMIOpen.so", - "libhsa-runtime64.so", - "libhsakmt.so", - ] + to_preload = [] to_preload_cuda = [] to_preload_tensorrt = [] to_preload_cann = [] - cuda_dependencies = [] - args = ["patchelf", "--debug"] - for line in result.stdout.split("\n"): - for dependency in dependencies: - if dependency in line: - to_preload.append(line) - args.extend(["--remove-needed", line]) - args.append(dest) - if len(args) > 3: - subprocess.run(args, check=True, stdout=subprocess.PIPE) - - dest = "onnxruntime/capi/libonnxruntime_providers_" + ("rocm.so" if is_rocm else "cuda.so") - if path.isfile(dest): - result = subprocess.run( - ["patchelf", "--print-needed", dest], - check=True, - stdout=subprocess.PIPE, - text=True, - ) - cuda_dependencies = [ - "libcublas.so", - "libcublasLt.so", - "libcudnn.so", - "libcudart.so", - "libcurand.so", - "libcufft.so", - "libnvToolsExt.so", - "libcupti.so", - ] - rocm_dependencies = [ - "librccl.so", - "libamdhip64.so", - "librocblas.so", - "libMIOpen.so", - "libhsa-runtime64.so", - "libhsakmt.so", - ] - args = ["patchelf", "--debug"] - for line in result.stdout.split("\n"): - for dependency in cuda_dependencies + rocm_dependencies: - if dependency in line: - if dependency not in to_preload: - to_preload_cuda.append(line) - args.extend(["--remove-needed", line]) - args.append(dest) - if len(args) > 3: - subprocess.run(args, check=True, stdout=subprocess.PIPE) - - dest = "onnxruntime/capi/libonnxruntime_providers_" + ("migraphx.so" if is_rocm else "tensorrt.so") - if path.isfile(dest): - result = subprocess.run( - ["patchelf", "--print-needed", dest], - check=True, - stdout=subprocess.PIPE, - text=True, - ) - tensorrt_dependencies = ["libnvinfer.so", "libnvinfer_plugin.so", "libnvonnxparser.so"] - args = ["patchelf", "--debug"] - for line in result.stdout.split("\n"): - for dependency in cuda_dependencies + tensorrt_dependencies: - if dependency in line: - if dependency not in (to_preload + to_preload_cuda): - to_preload_tensorrt.append(line) - args.extend(["--remove-needed", line]) - args.append(dest) - if len(args) > 3: - subprocess.run(args, check=True, stdout=subprocess.PIPE) - - dest = "onnxruntime/capi/libonnxruntime_providers_cann.so" - if path.isfile(dest): - result = subprocess.run( - ["patchelf", "--print-needed", dest], - check=True, - stdout=subprocess.PIPE, - text=True, - ) - cann_dependencies = ["libascendcl.so", "libacl_op_compiler.so", "libfmk_onnx_parser.so"] - args = ["patchelf", "--debug"] - for line in result.stdout.split("\n"): - for dependency in cann_dependencies: - if dependency in line: - if dependency not in to_preload: - to_preload_cann.append(line) - args.extend(["--remove-needed", line]) - args.append(dest) - if len(args) > 3: - subprocess.run(args, check=True, stdout=subprocess.PIPE) + + cuda_dependencies = [ + "libcublas.so.11", + "libcublasLt.so.11", + "libcudnn.so.8", + "libcudart.so.11.0", + "libcurand.so.10", + "libcufft.so.10", + ] + rocm_dependencies = [ + "librccl.so.1", + "libnuma.so.1", + "libamd_comgr.so.2", + "libdrm.so.2", + "librocblas.so.0", + "libdrm_amdgpu.so.1", + "libamdhip64.so.5", + "libroctracer64.so.4", + "libMIOpen.so.1", + "libtinfo.so.6", + "libelf.so.1", + "librocm_smi64.so.5", + "libhsa-runtime64.so.1", + ] + + tensorrt_dependencies = ["libnvinfer.so.8", "libnvinfer_plugin.so.8", "libnvonnxparser.so.8"] dest = "onnxruntime/capi/libonnxruntime_providers_openvino.so" if path.isfile(dest): @@ -308,10 +238,12 @@ def run(self): assert self.dist_dir is not None file = glob(path.join(self.dist_dir, "*linux*.whl"))[0] logger.info("repairing %s for manylinux1", file) + auditwheel_cmd = ["auditwheel", "-v", "repair", "-w", self.dist_dir, file] + for i in cuda_dependencies + rocm_dependencies + tensorrt_dependencies: + auditwheel_cmd += ["--exclude", i] + logger.info("Running {}".format(" ".join([shlex.quote(arg) for arg in auditwheel_cmd]))) try: - subprocess.run( - ["auditwheel", "repair", "-w", self.dist_dir, file], check=True, stdout=subprocess.PIPE - ) + subprocess.run(auditwheel_cmd, check=True, stdout=subprocess.PIPE) finally: logger.info("removing %s", file) remove(file) @@ -470,6 +402,7 @@ def finalize_options(self): "onnxruntime.transformers.models.bart", "onnxruntime.transformers.models.bert", "onnxruntime.transformers.models.gpt2", + "onnxruntime.transformers.models.llama", "onnxruntime.transformers.models.longformer", "onnxruntime.transformers.models.t5", "onnxruntime.transformers.models.stable_diffusion", diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 64dae354a9..638196e73a 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -96,7 +96,7 @@ def _openvino_verify_device_type(device_read): break def invalid_hetero_build(): - print("\nIf trying to build Hetero/Multi/Auto, specifiy the supported devices along with it.\n") + print("\nIf trying to build Hetero/Multi/Auto, specify the supported devices along with it.\n") print("specify the keyword HETERO or MULTI or AUTO followed by the devices ") print("in the order of priority you want to build\n") print("The different hardware devices that can be added in HETERO or MULTI or AUTO") @@ -107,7 +107,7 @@ def invalid_hetero_build(): sys.exit("Wrong Build Type selected") if res is False: - print("\nYou have selcted wrong configuration for the build.") + print("\nYou have selected wrong configuration for the build.") print("pick the build type for specific Hardware Device from following options: ", choices) print("(or) from the following options with graph partitioning disabled: ", choices1) print("\n") @@ -166,6 +166,15 @@ def convert_arg_line_to_args(self, arg_line): help="Use parallel build. The optional value specifies the maximum number of parallel jobs. " "If the optional value is 0 or unspecified, it is interpreted as the number of CPUs.", ) + parser.add_argument( + "--nvcc_threads", + nargs="?", + default=-1, + type=int, + help="Maximum number of NVCC threads in each parallel job." + "If the value is unspecified, it will be computed based on available memory and number of parallel jobs.", + ) + parser.add_argument("--test", action="store_true", help="Run unit tests.") parser.add_argument("--skip_tests", action="store_true", help="Skip all tests.") parser.add_argument( @@ -422,7 +431,7 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument("--wasm_run_tests_in_browser", action="store_true", help="Run WebAssembly tests in browser") parser.add_argument( - "--enable_wasm_profiling", action="store_true", help="Enable WebAsselby profiling and preserve function names" + "--enable_wasm_profiling", action="store_true", help="Enable WebAssembly profiling and preserve function names" ) parser.add_argument( "--enable_wasm_debug_info", action="store_true", help="Build WebAssembly with DWARF format debug info" @@ -519,7 +528,7 @@ def convert_arg_line_to_args(self, arg_line): "--llvm_config", type=str, default="", - help="Path to llvm-config.exe for LLVM buit from sources. It is strongly needed for build on Windows", + help="Path to llvm-config.exe for LLVM built from sources. It is strongly needed for build on Windows", ) parser.add_argument( "--skip_onnx_tests", @@ -864,6 +873,43 @@ def normalize_arg_list(nested_list): return [i for j in nested_list for i in j] if nested_list else [] +def number_of_parallel_jobs(args): + return os.cpu_count() if args.parallel == 0 else args.parallel + + +def number_of_nvcc_threads(args): + if args.nvcc_threads >= 0: + return args.nvcc_threads + + nvcc_threads = 1 + try: + import psutil + + available_memory = psutil.virtual_memory().available + if isinstance(available_memory, int) and available_memory > 0: + if available_memory > 60 * 1024 * 1024 * 1024: + # When available memory is large enough, chance of OOM is small. + nvcc_threads = 4 + else: + # NVCC need a lot of memory to compile 8 flash attention cu files in Linux or 4 cutlass fmha cu files in Windows. + # Here we select number of threads to ensure each thread has enough memory (>= 4 GB). For example, + # Standard_NC4as_T4_v3 has 4 CPUs and 28 GB memory. When parallel=4 and nvcc_threads=2, + # total nvcc threads is 4 * 2, which is barely able to build in 28 GB memory so we will use nvcc_threads=1. + memory_per_thread = 4 * 1024 * 1024 * 1024 + fmha_cu_files = 4 if is_windows() else 16 + fmha_parallel_jobs = min(fmha_cu_files, number_of_parallel_jobs(args)) + nvcc_threads = max(1, int(available_memory / (memory_per_thread * fmha_parallel_jobs))) + print( + f"nvcc_threads={nvcc_threads} to ensure memory per thread >= 4GB for available_memory={available_memory} and fmha_parallel_jobs={fmha_parallel_jobs}" + ) + except ImportError: + print( + "Failed to import psutil. Please `pip install psutil` for better estimation of nvcc threads. Use nvcc_threads=1" + ) + + return nvcc_threads + + def generate_build_tree( cmake_path, source_dir, @@ -1028,7 +1074,8 @@ def generate_build_tree( if args.use_migraphx: cmake_args.append("-Donnxruntime_MIGRAPHX_HOME=" + migraphx_home) if args.use_cuda: - cmake_args.append("-Donnxruntime_NVCC_THREADS=" + str(args.parallel)) + nvcc_threads = number_of_nvcc_threads(args) + cmake_args.append("-Donnxruntime_NVCC_THREADS=" + str(nvcc_threads)) if args.use_rocm: cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home) cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) @@ -1782,13 +1829,12 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): [sys.executable, "onnxruntime_test_python_symbolic_shape_infer.py"], cwd=cwd, dll_path=dll_path ) - # For CUDA enabled builds test IOBinding feature - if args.use_cuda: - # We need to have Torch installed to test the IOBinding feature - # which currently uses Torch's allocator to allocate GPU memory for testing + # For CUDA or DML enabled builds test IOBinding feature + if args.use_cuda or args.use_dml: log.info("Testing IOBinding feature") run_subprocess([sys.executable, "onnxruntime_test_python_iobinding.py"], cwd=cwd, dll_path=dll_path) + if args.use_cuda: log.info("Testing CUDA Graph feature") run_subprocess([sys.executable, "onnxruntime_test_python_cudagraph.py"], cwd=cwd, dll_path=dll_path) @@ -2212,7 +2258,9 @@ def generate_documentation(source_dir, build_dir, configs, validate): have_diff = False def diff_file(path, regenerate_qualifiers=""): - diff = subprocess.check_output(["git", "diff", path], cwd=source_dir).decode("utf-8") + diff = subprocess.check_output(["git", "diff", "--ignore-blank-lines", path], cwd=source_dir).decode( + "utf-8" + ) if diff: nonlocal have_diff have_diff = True @@ -2240,6 +2288,8 @@ def main(): args = parse_arguments() + print(args) + if os.getenv("ORT_BUILD_WITH_CACHE") == "1": args.use_cache = True @@ -2525,7 +2575,7 @@ def main(): if args.build: if args.parallel < 0: raise BuildError(f"Invalid parallel job count: {args.parallel}") - num_parallel_jobs = os.cpu_count() if args.parallel == 0 else args.parallel + num_parallel_jobs = number_of_parallel_jobs(args) build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, args.target) if args.test: diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index ea998963d9..608112181b 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -195,7 +195,7 @@ stages: buildArch: x64 msbuildPlatform: x64 packageName: x64-cuda - buildparameter: --use_cuda --cuda_version=11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ${{parameters.AdditionalBuildFlag}} + buildparameter: --use_cuda --cuda_version=11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" ${{parameters.AdditionalBuildFlag}} runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: true java_artifact_id: onnxruntime_gpu @@ -211,7 +211,7 @@ stages: buildArch: x64 msbuildPlatform: x64 packageName: x64-tensorrt - buildparameter: --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_version=11.8 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + buildparameter: --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_version=11.8 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: true java_artifact_id: onnxruntime_gpu @@ -356,19 +356,32 @@ stages: - checkout: self submodules: false - template: templates/set-version-number-variables-step.yml - - task: DownloadPipelineArtifact@2 - displayName: 'Download Final Jar' - inputs: - buildType: 'current' - artifactName: 'onnxruntime-java-gpu' - targetPath: '$(Build.BinariesDirectory)/final-jar' - - task: Bash@3 + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Final Jar' + ArtifactName: onnxruntime-java-gpu + TargetPath: '$(Build.BinariesDirectory)/final-jar' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 + Context: tools/ci_build/github/linux/docker/ + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimeubi8packagestest + UpdateDepsTxt: false + + - bash: | + docker run --rm \ + --gpus all \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + onnxruntimeubi8packagestest \ + /bin/bash /onnxruntime_src/tools/ci_build/github/linux/java_linux_final_test.sh -r /build -v $(OnnxRuntimeVersion) displayName: 'Test' - inputs: - targetType: filePath - filePath: 'tools/ci_build/github/linux/java_linux_final_test.sh' - arguments: '-r $(Build.BinariesDirectory) -v $(OnnxRuntimeVersion)' - template: templates/component-governance-component-detection-steps.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml index b784ef72d6..1e3d20b857 100644 --- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -45,10 +45,6 @@ stages: clean: true submodules: none - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - task: UsePythonVersion@0 inputs: versionSpec: '3.8' @@ -74,11 +70,11 @@ stages: inputs: script: | mkdir -p $HOME/.onnx - mkdir -p $(Pipeline.Workspace)/ccache docker run --rm \ --volume /data/onnx:/data/onnx:ro \ --volume $(Build.SourcesDirectory):/onnxruntime_src \ --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ --volume $(ORT_CACHE_DIR):/cache \ -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ @@ -90,135 +86,20 @@ stages: set -ex; \ ccache -s; \ /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build --cmake_generator Ninja \ + --build_dir /build --cmake_generator 'Ninja' \ --config Debug Release \ --skip_submodule_sync \ --build_shared_lib \ --parallel \ --build_wheel \ --build_csharp \ - --enable_onnx_tests \ - --enable_transformers_tool_test \ + --enable_onnx_tests --enable_symbolic_shape_infer_tests \ --use_cache \ - --build_java --build_nodejs --update --build --cmake_extra_defines onnxruntime_BUILD_BENCHMARKS=ON; \ + --build_java --build_nodejs --cmake_extra_defines onnxruntime_BUILD_BENCHMARKS=ON; \ ccache -sv; \ ccache -z" workingDirectory: $(Build.SourcesDirectory) - - task: UseDotNet@2 - displayName: "Setup dotnet" - inputs: - version: '6.0.408' - - - task: DotNetCoreCLI@2 - displayName: "Restore C# packages" - inputs: - command: 'restore' - projects: '$(Build.SourcesDirectory)/csharp/OnnxRuntime.DesktopOnly.CSharp.sln' - - # the props file was generated with docker container paths. convert to the 'real' path by replacing the - # the container path of '/build'. The '>' prefix is to match the closing angle bracket of the tag. - # e.g. /build/... so we only match the start of a path. - # We use powershell so we don't need extra escaping of the '/' chars in the path. - - task: CmdLine@2 - displayName: 'Update props from docker path to local and create models link' - inputs: - script: | - pwsh -Command '(Get-Content $(Build.SourcesDirectory)/csharp/Directory.Build.props) -replace ">/build", ">$(Build.BinariesDirectory)" | Set-Content $(Build.SourcesDirectory)/csharp/Directory.Build.props' - cat $(Build.SourcesDirectory)/csharp/Directory.Build.props - ln -s /data/models $(Build.BinariesDirectory)/models - - - task: DotNetCoreCLI@2 - displayName: 'dotnet build C# sln' - inputs: - command: 'build' - projects: '$(Build.SourcesDirectory)/csharp/OnnxRuntime.DesktopOnly.CSharp.sln' - - - task: DotNetCoreCLI@2 - displayName: 'dotnet test C#' - inputs: - command: 'test' - projects: '$(Build.SourcesDirectory)/csharp/OnnxRuntime.DesktopOnly.CSharp.sln' - # extra logging so all tests are listed in output to validate what's actually run - arguments: '-f net6.0 --no-build -l "console;verbosity=normal"' - workingDirectory: $(Build.SourcesDirectory)/csharp - - - task: CmdLine@2 - displayName: 'Install python deps and run java tests' - inputs: - script: | - set -e -x - python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml onnx -qq - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt $(Build.BinariesDirectory)/requirements.txt - # Test ORT with the latest ONNX release. - sed -i "s/git+http:\/\/github\.com\/onnx\/onnx.*/onnx/" $(Build.BinariesDirectory)/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements.txt - mkdir $(Build.BinariesDirectory)/requirements_torch_cpu/ - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - cd $(Build.SourcesDirectory)/java - $(Build.SourcesDirectory)/java/gradlew "cmakeCheck" "-DcmakeBuildDir=$(Build.BinariesDirectory)/Release" - - - task: CmdLine@2 - displayName: 'Install Release python package' - inputs: - script: | - rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 - python3 -m pip install $(Build.BinariesDirectory)/Release/dist/*.whl - - - task: PythonScript@0 - displayName: 'Run Release unit tests' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/build.py - workingDirectory: $(Build.BinariesDirectory)/Release - arguments: >- - --build_dir $(Build.BinariesDirectory) - --cmake_generator Ninja - --config Release - --test - --skip_submodule_sync - --build_shared_lib - --parallel - --build_wheel - --enable_onnx_tests - --enable_transformers_tool_test - --build_nodejs - --ctest_path "" - - - task: CmdLine@2 - displayName: 'Install Debug python package' - inputs: - script: | - set -e -x - rm -rf $(Build.BinariesDirectory)/Debug/onnxruntime $(Build.BinariesDirectory)/Debug/pybind11 - python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml -qq - python3 -m pip install $(Build.BinariesDirectory)/Debug/dist/*.whl - - - task: PythonScript@0 - displayName: 'Run Debug unit tests' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/build.py - workingDirectory: $(Build.BinariesDirectory)/Debug - arguments: >- - --build_dir $(Build.BinariesDirectory) - --cmake_generator Ninja - --config Debug - --test - --skip_submodule_sync - --build_shared_lib - --parallel - --build_wheel - --enable_onnx_tests - --enable_transformers_tool_test - --build_nodejs - --ctest_path "" - - - task: PythonScript@0 - displayName: 'Symbolic shape infer' - inputs: - scriptPath: $(Build.BinariesDirectory)/Release/onnxruntime_test_python_symbolic_shape_infer.py - workingDirectory: $(Build.BinariesDirectory)/Release - - task: PublishTestResults@2 displayName: 'Publish unit test results' inputs: @@ -245,8 +126,11 @@ stages: - stage: arm64_test dependsOn: ['arm64_build'] jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'aarch64' machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' - device: 'CPU' + base_image: 'arm64v8/centos:7' + devtoolset_rootpath: /opt/rh/devtoolset-10/root + ld_library_path_arg: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/devtoolset-10/root/usr/bin:' diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml index 5dc8fffbfe..461a62496c 100644 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml @@ -50,10 +50,6 @@ jobs: clean: true submodules: recursive - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - template: templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_aten_cpu diff --git a/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml index c1f1c39c85..f0c4422c7e 100644 --- a/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml @@ -33,5 +33,3 @@ jobs: JobName: 'Linux_CI_Multi_GPU_TensorRT_Dev' # The latest TensorRT container only supports ubuntu20.04 and python 3.8 RunDockerBuildArgs: '-o ubuntu20.04 -d tensorrt -x "--enable_multi_device_test"' - DoNugetPack: 'false' - ArtifactName: 'drop-linux' diff --git a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml index 2938b87ec6..0264086c12 100644 --- a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml @@ -32,6 +32,4 @@ jobs: AgentPool : 'Linux-CPU-2019' JobName: 'Linux_CI_Dev' RunDockerBuildArgs: '-o ubuntu20.04 -d openvino -v 2023.0.0 -x "--use_openvino CPU_FP32 --build_wheel"' - DoNugetPack: 'false' - ArtifactName: 'drop-linux' TimeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 53596a5ad5..12696e166a 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -103,6 +103,7 @@ jobs: ./build/Release/onnx_test_runner -e qnn \ -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \ /data/qdq_models + enabled: false - task: CmdLine@2 displayName: Run QDQ model tests with context cache enabled diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml index 4563a79adb..864d1002a9 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml @@ -1,5 +1,5 @@ parameters: - AgentPool: 'onnxruntime-Ubuntu2004-AMD-CPU' + AgentPool: 'Azure-Pipelines-EO-Ubuntu-2004-aiinfra' StageSuffix: '' stages: - stage: Nodejs_Test_${{ parameters.StageSuffix }} @@ -18,4 +18,4 @@ stages: value: '$(Build.BinariesDirectory)' steps: - template: test.yml - \ No newline at end of file + diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml index cbe4e805bb..b07d9a6089 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml @@ -50,19 +50,52 @@ stages: script: | ln -sf /data/models $(Build.BinariesDirectory) - - task: Bash@3 - displayName: 'Run Package Test' - inputs: - targetType: filePath - filePath: '$(Build.SourcesDirectory)/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh' - arguments: '$(Build.BinariesDirectory)/nuget-artifact $(NuGetPackageVersionNumber)' - workingDirectory: $(Build.BinariesDirectory) - env: - OnnxRuntimeBuildDirectory: $(Build.BinariesDirectory) - DisableContribOps: $(DisableContribOps) - DisableMlOps: $(DisableMlOps) - IsReleaseBuild: $(IsReleaseBuild) - PACKAGENAME: ${{ parameters.NugetPackageName }} + - ${{if contains(parameters.StageSuffix , 'GPU') }}: + - template: ../../templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 + Context: tools/ci_build/github/linux/docker/ + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimepackagestest + - bash: | + docker run --rm \ + --gpus all \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + -e BUILD_SOURCESDIRECTORY='/onnxruntime_src' \ + -e OnnxRuntimeBuildDirectory='/build' \ + -e DisableContribOps='$(DisableContribOps)' \ + -e DisableMlOps='$(DisableMlOps)' \ + -e IsReleaseBuild='$(IsReleaseBuild)' \ + -e PACKAGENAME='${{ parameters.NugetPackageName }}' \ + onnxruntimepackagestest \ + /bin/bash -c " + set -ex; \ + pushd /build; \ + bash /onnxruntime_src/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh /build/nuget-artifact $(NuGetPackageVersionNumber); \ + popd + " + displayName: 'Run Package Test' + - ${{ else }}: + - task: CmdLine@2 + displayName: 'Create symlink for test models' + inputs: + script: | + ln -sf /data/models $(Build.BinariesDirectory) + - task: Bash@3 + displayName: 'Run Package Test' + inputs: + targetType: filePath + filePath: '$(Build.SourcesDirectory)/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh' + arguments: '$(Build.BinariesDirectory)/nuget-artifact $(NuGetPackageVersionNumber)' + workingDirectory: $(Build.BinariesDirectory) + env: + OnnxRuntimeBuildDirectory: $(Build.BinariesDirectory) + DisableContribOps: $(DisableContribOps) + DisableMlOps: $(DisableMlOps) + IsReleaseBuild: $(IsReleaseBuild) + PACKAGENAME: ${{ parameters.NugetPackageName }} - template: ../../templates/component-governance-component-detection-steps.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml index f5b221f23f..355faa8b98 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml @@ -51,10 +51,6 @@ jobs: clean: true submodules: none - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - task: UsePythonVersion@0 inputs: versionSpec: '3.8' @@ -85,6 +81,7 @@ jobs: mkdir -p $(Pipeline.Workspace)/ccache docker run --rm \ --volume /data/onnx:/data/onnx:ro \ + --volume /data/models:/build/models:ro \ --volume $(Build.SourcesDirectory):/onnxruntime_src \ --volume $(Build.BinariesDirectory):/build \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ @@ -107,53 +104,11 @@ jobs: --enable_onnx_tests \ --enable_training \ --use_cache \ - --build_java --build_nodejs --update --build; \ + --build_java --build_nodejs; \ ccache -sv; \ ccache -z" workingDirectory: $(Build.SourcesDirectory) - - task: CmdLine@2 - displayName: 'Install python deps and run java tests' - inputs: - script: | - set -e -x - python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml onnx -qq - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt $(Build.BinariesDirectory)/requirements.txt - # Test ORT with the latest ONNX release. - sed -i "s/git+http:\/\/github\.com\/onnx\/onnx.*/onnx/" $(Build.BinariesDirectory)/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements.txt - mkdir $(Build.BinariesDirectory)/requirements_torch_cpu/ - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - cd $(Build.SourcesDirectory)/java - $(Build.SourcesDirectory)/java/gradlew "cmakeCheck" "-DcmakeBuildDir=$(Build.BinariesDirectory)/Release" - - - task: CmdLine@2 - displayName: 'Install Release python package' - inputs: - script: | - rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 - python3 -m pip install $(Build.BinariesDirectory)/Release/dist/*.whl - - - task: PythonScript@0 - displayName: 'Run Release unit tests' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/build.py - workingDirectory: $(Build.BinariesDirectory)/Release - arguments: >- - --build_dir $(Build.BinariesDirectory) - --cmake_generator Ninja - --config Release - --test - --skip_submodule_sync - --build_shared_lib - --parallel - --build_wheel - --enable_onnx_tests - --enable_training - --build_nodejs - --ctest_path "" - - task: PublishTestResults@2 displayName: 'Publish unit test results' inputs: diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml index 16d70a58a0..da0a2a6026 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml @@ -18,7 +18,6 @@ jobs: parameters: AgentPool : 'Onnxruntime-Linux-GPU-NC6sv3' JobName: 'Onnxruntime_Linux_GPU_Training' - SubmoduleCheckoutMode: 'recursive' RunDockerBuildArgs: > -o ubuntu20.04 -d gpu -t onnxruntime_orttraining_ortmodule_tests_image @@ -26,24 +25,16 @@ jobs: -e -x " --enable_training - --config $(buildConfig) + --config Release --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 --build_wheel --enable_nvtx_profile --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=70 " - DoNugetPack: 'false' RunInjectedPipeline: 'true' InjectedPipeline: 'orttraining-linux-gpu-ortmodule-test-ci-pipeline.yml' DockerImageTag: 'onnxruntime_orttraining_ortmodule_tests_image' - BuildConfig: $(buildConfig) - ArtifactName: 'drop-linux' TimeoutInMinutes: 140 # Enable unreleased onnx opsets in CI builds # This facilitates testing the implementation for the new opsets AllowReleasedOpsetOnly: '0' - Strategy: - maxParallel: 2 - matrix: - Release: - buildConfig: Release diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml index 8806707d21..ac551a53cd 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml @@ -1,14 +1,5 @@ trigger: none -variables: - - name: isMain - value: ${{ eq(variables['Build.SourceBranch'], 'refs/heads/main') }} - - name: finalStorage - ${{ if eq(variables['isMain'], 'true') }}: - value: '--final_storage' - ${{ else }}: - value: '' - resources: repositories: - repository: manylinux @@ -39,14 +30,6 @@ stages: PythonVersion: '3.11' steps: - - task: CmdLine@2 - displayName: 'check variables' - inputs: - script: | - echo "Branch is "${{ variables['Build.SourceBranch'] }} && \ - echo "isMain is "${{ variables['isMain'] }} && \ - echo "final_storage is "${{ variables['finalStorage'] }} - - checkout: self clean: true submodules: recursive @@ -102,17 +85,6 @@ stages: inputs: ArtifactName: onnxruntime_training_cpu - - task: CmdLine@2 - condition: succeeded() - displayName: 'Upload wheel' - inputs: - script: | - files=($(Build.ArtifactStagingDirectory)/Release/dist/*.whl) && \ - echo ${files[0]} && \ - echo ${{ variables['finalStorage'] }} && \ - tools/ci_build/upload_python_package_to_azure_storage.py \ - --python_wheel_path ${files[0]} ${{ variables['finalStorage'] }} - - template: templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index 872812a6ae..b858770583 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -74,7 +74,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false ORT_EP_NAME: CUDA WITH_CACHE: true MachinePool: onnxruntime-Win2022-GPU-MultiA10 @@ -93,7 +92,6 @@ stages: isX86: false job_name_suffix: x64_mimalloc RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -113,7 +111,6 @@ stages: isX86: false job_name_suffix: x64_no_memory_profiling RunOnnxRuntimeTests: false - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -133,7 +130,6 @@ stages: isX86: false job_name_suffix: x64_minimal_no_exception RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -153,7 +149,6 @@ stages: isX86: false job_name_suffix: x64_debug_node_input_output RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index c684e08ba1..5e01936626 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -3,24 +3,38 @@ resources: - pipeline: build source: 'Python packaging pipeline' trigger: true + branch: rel-1.16.0 # branch to pick the artifact, Used only for manual triggered pipeline runs for testing the pipeline itself + #TODO: Remove the following dependency. Running python tests should not need to use manylinux. + repositories: + - repository: manylinux # The name used to reference this repository in the checkout step + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 stages: - stage: Linux_Test_CPU_x86_64_stage jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'x86_64' machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' - device: 'CPU' + base_image: 'centos:7' + devtoolset_rootpath: /opt/rh/devtoolset-11/root + ld_library_path_arg: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/devtoolset-11/root/usr/bin:' - stage: Linux_Test_CPU_aarch64_stage dependsOn: [] jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'aarch64' - machine_pool: 'aiinfra-linux-ARM64-CPU-2019' - device: 'CPU' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' + base_image: 'arm64v8/centos:7' + devtoolset_rootpath: /opt/rh/devtoolset-10/root + ld_library_path_arg: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/devtoolset-10/root/usr/bin:' - stage: Packages_Somking_Test dependsOn: [] @@ -31,19 +45,6 @@ stages: machine_pool: vmImage: 'macOS-13' itemPattern: '*/*mac*x86_64.whl' - - template: templates/py-package-smoking-test.yml - parameters: - job_name: Test_WIN_64_Wheels - itemPattern: '*/*win_amd64.whl' - machine_pool: - vmImage: 'windows-2022' - - template: templates/py-package-smoking-test.yml - parameters: - job_name: Test_WIN_32_Wheels - itemPattern: '*/*win32.whl' - python_arch: 'x86' - machine_pool: - vmImage: 'windows-2022' - template: templates/py-package-smoking-test.yml parameters: job_name: Test_LINUX_x86_64_Wheels @@ -61,7 +62,7 @@ stages: - Linux_Test_CPU_aarch64_stage - Packages_Somking_Test jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cuda.yml parameters: arch: 'x86_64' machine_pool: 'Onnxruntime-Linux-GPU' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 74007d9b55..af245c9970 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -495,7 +495,7 @@ stages: PackageType: 'nuget' PackagePath: '$(Build.ArtifactStagingDirectory)' PackageName: 'Microsoft.ML.OnnxRuntime.*nupkg' - PlatformsSupported: 'win-x64,win-x86,linux-x64,linux-arm64,osx.10.14-x64' + PlatformsSupported: 'win-x64,win-x86,linux-x64,linux-arm64,osx-x64' VerifyNugetSigning: false - task: PublishPipelineArtifact@0 @@ -804,7 +804,7 @@ stages: - template: ../nodejs/templates/test_linux.yml parameters: - AgentPool : 'onnxruntime-Ubuntu2004-AMD-CPU' + AgentPool : 'Azure-Pipelines-EO-Ubuntu-2004-aiinfra' StageSuffix : 'Linux_CPU_x64' - template: ../nodejs/templates/test_macos.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/compliance.yml b/tools/ci_build/github/azure-pipelines/templates/compliance.yml index 04d999b556..0dfe398c8b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/compliance.yml +++ b/tools/ci_build/github/azure-pipelines/templates/compliance.yml @@ -18,27 +18,6 @@ steps: arguments: 'analyze $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\*.dll --recurse --verbose' continueOnError: true -- task: DeleteFiles@1 - displayName: 'Delete files from $(Build.BinariesDirectory)\RelWithDebInfo' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - Contents: | - **/*.obj - **/*.pdb - **/*.dll - -# Manually set msBuildCommandline so that we can also set CAExcludePath -- task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - inputs: - userProvideBuildInfo: msBuildInfo - msBuildArchitecture: x64 - msBuildVersion: 17.0 - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln" /p:platform="${{parameters.msbuildPlatform}}" /p:configuration="RelWithDebInfo" /p:CAExcludePath="$(Build.BinariesDirectory);$(Build.SourcesDirectory)\cmake;C:\program files (x86)" /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.SourcesDirectory)\b#$(Build.SourcesDirectory)\cmake#C:\program files#C:\program files (x86)#C:\program files' - rulesetName: Custom - customRuleset: $(Build.SourcesDirectory)\cmake\Sdl.ruleset - - task: SdtReport@2 displayName: 'Create Security Analysis Report' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index c653ba2992..e75f29b3da 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.77 + version: 1.0.112 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.77 + version: 1.0.112 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/flex-downloadPipelineArtifact.yml b/tools/ci_build/github/azure-pipelines/templates/flex-downloadPipelineArtifact.yml index 0f4e0553d0..a83451a1b3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/flex-downloadPipelineArtifact.yml +++ b/tools/ci_build/github/azure-pipelines/templates/flex-downloadPipelineArtifact.yml @@ -18,7 +18,7 @@ parameters: steps: - task: DownloadPipelineArtifact@2 - displayName: ${{ parameters.StepName }}} + displayName: ${{ parameters.StepName }} inputs: artifactName: ${{ parameters.ArtifactName}} targetPath: '${{ parameters.TargetPath }}' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index 1cd21ea199..b05602a57b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -34,11 +34,6 @@ parameters: type: boolean default: true -- name: RunStaticCodeAnalysis - displayName: Run Static Code Analysis - type: boolean - default: true - - name: ORT_EP_NAME type: string @@ -105,7 +100,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' force32bit: ${{ parameters.isX86 }} @@ -309,49 +304,6 @@ jobs: workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\${{ parameters.BuildConfig }}' displayName: 'Run tests' - - - ${{ if eq(parameters.RunStaticCodeAnalysis, true) }}: - - task: DeleteFiles@1 - displayName: 'Delete binaries files from $(Build.BinariesDirectory)\RelWithDebInfo' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - Contents: | - **/*.obj - **/*.pdb - **/*.dll - - - # Manually set msBuildCommandline so that we can also set CAExcludePath - # build_dir must be a sub folder of $(Build.SourcesDirectory) - # TODO: move this step to a CPU-only machine to save GPU resources. - - task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - inputs: - msBuildArchitecture: amd64 - setupCommandlines: 'python $(Build.SourcesDirectory)\tools\ci_build\build.py --config RelWithDebInfo --build_dir $(Build.SourcesDirectory)\b --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests ${{ parameters.additionalBuildFlags }} --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON onnxruntime_ENABLE_LTO=OFF' - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.SourcesDirectory)\b\RelWithDebInfo\onnxruntime.sln" /p:RunCodeAnalysis=true /p:platform=${{ parameters.msbuildPlatform }} /p:configuration=RelWithDebInfo /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.SourcesDirectory)\b#$(Build.SourcesDirectory)\cmake#C:\program files#C:\program files (x86)#C:\program files' - rulesetName: Custom - customRuleset: $(Build.SourcesDirectory)\cmake\Sdl.ruleset - publishXML: true - - - task: SdtReport@2 - displayName: 'Create Security Analysis Report' - inputs: - SDLNativeRules: true - - - task: PublishSecurityAnalysisLogs@3 - displayName: 'Publish Security Analysis Logs' - continueOnError: true - - - task: PostAnalysis@2 - displayName: 'Guardian Break v2' - inputs: - GdnBreakGdnToolSDLNativeRulesSeverity: Note - GdnBreakGdnToolSDLNativeRules: true - - - - ${{ if eq(parameters.RunOnnxRuntimeTests, true) }}: - task: PublishTestResults@2 displayName: 'Publish unit test results' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml index 05b2dee77e..7b9788d90b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml @@ -1,23 +1,14 @@ parameters: AgentPool : 'onnxruntime-Ubuntu2004-AMD-CPU' StageName : 'Linux_CI_Dev' - SubmoduleCheckoutMode: '' RunDockerBuildArgs: '-o ubuntu20.04 -d cpu -x "--build_wheel"' - DoNodejsPack: 'false' - DoNugetPack: 'false' NuPackScript: '' RunInjectedPipeline: 'false' InjectedPipeline: '' DockerImageTag: '' - BuildConfig: '' - ArtifactName: 'drop-linux' TimeoutInMinutes: 120 # Controls whether unreleased onnx opsets are allowed. Default is set to 1 AllowReleasedOpsetOnly: '1' - # to inject strategy, you need to pass in the whole yaml structure - - # https://docs.microsoft.com/en-us/azure/devops/pipelines/yaml-schema?view=azure-devops&tabs=schema#strategies - # see example in orttraining-linux-gpu-ci-pipeline.yml - Strategy: '' jobs: - job: ${{ parameters.StageName }} @@ -28,16 +19,8 @@ jobs: ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} skipComponentGovernanceDetection: true pool: ${{ parameters.AgentPool }} - ${{ if ne(parameters.Strategy, '') }}: - strategy: - ${{ parameters.Strategy }} steps: - checkout: self - ${{ if ne(parameters.SubmoduleCheckoutMode, '') }}: - submodules: ${{ parameters.SubmoduleCheckoutMode }} - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - template: run-docker-build-steps.yml parameters: RunDockerBuildArgs: '${{ parameters.RunDockerBuildArgs }}' @@ -48,31 +31,10 @@ jobs: searchFolder: '$(Build.BinariesDirectory)' testRunTitle: 'Unit Test Run' condition: succeededOrFailed() - - ${{ if eq(parameters['DoNugetPack'], 'true') }}: - - script: | - ${{ parameters.NuPackScript }} - displayName: 'Create Artifacts' - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact' - inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - - ${{ if eq(parameters['DoNodejsPack'], 'true') }}: - - script: | - npm pack - cp $(Build.SourcesDirectory)/js/node/onnxruntime-*.tgz $(Build.ArtifactStagingDirectory) - cp -R $(Build.SourcesDirectory)/js/node/prebuilds $(Build.ArtifactStagingDirectory)/prebuilds - workingDirectory: '$(Build.SourcesDirectory)/js/node' - displayName: 'Create NPM Package' - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact: ${{ parameters.ArtifactName }}' - inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - ${{ if eq(parameters['RunInjectedPipeline'], 'true') }}: - template: | ${{ parameters.InjectedPipeline }} parameters: DockerImageTag: ${{ parameters.DockerImageTag }} - BuildConfig: ${{ parameters.BuildConfig }} + BuildConfig: Release - template: clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml index a2ad934f7f..b6c23fdbb4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml @@ -47,7 +47,7 @@ stages: OnnxruntimeCFlags: '-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -O3 -Wl,--strip-all' OnnxruntimeCXXFlags: '-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -O3 -Wl,--strip-all' OnnxruntimeNodejsBindingArch: 'arm64' - PoolName: 'aiinfra-linux-ARM64-CPU-2019' + PoolName: 'onnxruntime-linux-ARM64-CPU-2019' ArtifactNamePrefix: ${{ parameters.ArtifactNamePrefix }} PackageJava: ${{ parameters.PackageJava }} PackageNodeJS: ${{ parameters.PackageNodeJS }} diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml index 5b9ffac6fa..e3cfa417d8 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml @@ -46,7 +46,7 @@ stages: docker run --gpus all -e CC=/opt/rh/devtoolset-11/root/usr/bin/cc -e CXX=/opt/rh/devtoolset-11/root/usr/bin/c++ -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e NVIDIA_VISIBLE_DEVICES=all --rm --volume /data/onnx:/data/onnx:ro --volume $(Build.SourcesDirectory):/onnxruntime_src --volume $(Build.BinariesDirectory):/build \ --volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda118xtrt86build \ /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release \ - --skip_submodule_sync --parallel --build_shared_lib ${{ parameters.buildJavaOption }} --use_tensorrt --cuda_version=$(CUDA_VERSION) --cuda_home=/usr/local/cuda-$(CUDA_VERSION) --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines CMAKE_CUDA_HOST_COMPILER=/opt/rh/devtoolset-11/root/usr/bin/cc 'CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80' + --skip_submodule_sync --parallel --build_shared_lib ${{ parameters.buildJavaOption }} --use_tensorrt --cuda_version=$(CUDA_VERSION) --cuda_home=/usr/local/cuda-$(CUDA_VERSION) --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines CMAKE_CUDA_HOST_COMPILER=/opt/rh/devtoolset-11/root/usr/bin/cc 'CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80' workingDirectory: $(Build.SourcesDirectory) - ${{ if eq(parameters.buildJava, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 93945a1cb5..4ee442a122 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -79,7 +79,7 @@ jobs: architecture: $(buildArch) - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: download-deps.yml - task: PythonScript@0 @@ -94,7 +94,6 @@ jobs: cd '$(Build.SourcesDirectory)/cmake/external/emsdk' ./emsdk install 3.1.44 ccache-git-emscripten-64bit ./emsdk activate 3.1.44 ccache-git-emscripten-64bit - ln -s $(Build.SourcesDirectory)/cmake/external/emsdk/ccache/git-emscripten_64bit/bin/ccache /usr/local/bin/ccache displayName: 'emsdk install and activate ccache for emscripten' condition: eq('${{ parameters.WithCache }}', 'true') diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index adfcd98e37..f5e5435cfa 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -50,7 +50,7 @@ jobs: versionSpec: 3.11 - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: set-version-number-variables-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml b/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml index 76fbf55331..79feae8cf5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml +++ b/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml @@ -57,7 +57,7 @@ stages: REM use a single .csv file to put the data echo os,arch,build_config,size > $(Build.BinariesDirectory)\binary_size_data.txt 7z.exe l -slt %%~ni.zip runtimes\linux-arm64\native\libonnxruntime.so | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo linux,aarch64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt - 7z.exe l -slt %%~ni.zip runtimes\osx.10.14-x64\native\libonnxruntime.dylib | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo osx,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt + 7z.exe l -slt %%~ni.zip runtimes\osx-x64\native\libonnxruntime.dylib | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo osx,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt 7z.exe l -slt %%~ni.zip runtimes\win-x64\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt 7z.exe l -slt %%~ni.zip runtimes\win-x86\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x86,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt ) diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml index cee3bd9c9e..8d5ca19a73 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml @@ -39,36 +39,22 @@ jobs: versionSpec: $(PythonVersion) architecture: ${{ parameters.python_arch }} - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'onnxruntime' - targetPath: '$(Build.BinariesDirectory)/whl' - itemPattern: ${{parameters.itemPattern}} - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific + - download: build # pipeline resource identifier. + artifact: 'onnxruntime' - task: Bash@3 inputs: targetType: 'inline' script: | set -ex - files=(whl/*.whl) + files=(*.whl) FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') - python3 -m pip install --find-links "$(Build.BinariesDirectory)/whl" $PYTHON_PACKAGE_NAME - pip show $PYTHON_PACKAGE_NAME - python -c "import onnxruntime as ort; print(ort.__version__)" - workingDirectory: $(Build.BinariesDirectory) + python3 -m pip install --find-links "$(Pipeline.Workspace)/build/onnxruntime" $PYTHON_PACKAGE_NAME + python3 -m pip show $PYTHON_PACKAGE_NAME + python3 -c "import onnxruntime as ort; print(ort.__version__)" + workingDirectory: $(Pipeline.Workspace)/build/onnxruntime displayName: Test Package Installation - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml new file mode 100644 index 0000000000..00e40b520e --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml @@ -0,0 +1,111 @@ +parameters: +- name: arch + type: string + +- name: base_image + type: string + +- name: devtoolset_rootpath + type: string + +- name: ld_library_path_arg + type: string + +- name: prepend_path + type: string + +- name: machine_pool + type: string + +- name: extra_job_id + type: string + default: '' + +- name: python_wheel_suffix + type: string + default: '' + + +# TODO: Ideally it should fetch information from the build that triggers it +- name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: timeout + type: number + default: 120 + +jobs: +- job: Linux_Test_CPU${{ parameters.extra_job_id }}_${{ parameters.arch }} + timeoutInMinutes: ${{ parameters.timeout }} + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: ${{ parameters.machine_pool }} + steps: + - checkout: self + clean: true + submodules: none + # The public ADO project + - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: + - download: current # pipeline resource identifier. + artifact: 'drop-linux-cpu-${{ parameters.arch }}' + + - download: current # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + mv "$(Pipeline.Workspace)/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - download: build # pipeline resource identifier. + artifact: 'drop-linux-cpu-${{ parameters.arch }}' + + - download: build # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + ls $(Pipeline.Workspace)/build + mv "$(Pipeline.Workspace)/build/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + + # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet + - ${{ if eq(parameters.arch, 'x86_64') }}: + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '$(Build.BinariesDirectory)/tmp/**/*.so' + continueOnError: true + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2014_cpu + Context: tools/ci_build/github/linux/docker/inference/x64/python/cpu + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{ parameters.base_image }} --build-arg PLATFORM=${{ parameters.arch }} --build-arg PREPEND_PATH=${{ parameters.prepend_path }} --build-arg LD_LIBRARY_PATH_ARG=${{ parameters.ld_library_path_arg }} --build-arg DEVTOOLSET_ROOTPATH=${{ parameters.devtoolset_rootpath }}" + Repository: onnxruntimecpubuildpython${{ parameters.arch }} + ${{ if eq(parameters.arch, 'aarch64') }}: + UpdateDepsTxt: false + + - task: Bash@3 + displayName: 'Bash Script' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockertest.sh + arguments: -d CPU -c ${{parameters.cmake_build_type}} -i onnxruntimecpubuildpython${{ parameters.arch }} + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml new file mode 100644 index 0000000000..c521245b33 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -0,0 +1,93 @@ +parameters: +- name: arch + type: string + +- name: device + type: string + values: + - CPU + - GPU + +- name: machine_pool + type: string + +- name: extra_job_id + type: string + default: '' + +- name: python_wheel_suffix + type: string + default: '' + + +# TODO: Ideally it should fetch information from the build that triggers it +- name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: timeout + type: number + default: 120 + +jobs: +- job: Linux_Test_GPU${{ parameters.extra_job_id }}_${{ parameters.arch }} + timeoutInMinutes: ${{ parameters.timeout }} + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: ${{ parameters.machine_pool }} + steps: + - checkout: self + clean: true + submodules: none + # The public ADO project + # - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: + + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - download: build # pipeline resource identifier. + artifact: 'drop-linux-gpu-${{ parameters.arch }}' + + - download: build # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + ls $(Pipeline.Workspace)/build + mv "$(Pipeline.Workspace)/build/drop-linux-gpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + + # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet + - ${{ if eq(parameters.arch, 'x86_64') }}: + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '$(Build.BinariesDirectory)/tmp/**/*.so' + continueOnError: true + + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_8_tensorrt8_6 + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: "--network=host --build-arg POLICY=manylinux2014 --build-arg PLATFORM=x86_64 --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/devtoolset-11/root --build-arg PREPEND_PATH=/opt/rh/devtoolset-11/root/usr/bin: --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst:/usr/local/lib64 --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }}" + Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }} + + - task: Bash@3 + displayName: 'Bash Script' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockertest.sh + arguments: -d GPU -c ${{parameters.cmake_build_type}} -i onnxruntimecuda118xtrt86build${{ parameters.arch }} + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml deleted file mode 100644 index 8ddc917e85..0000000000 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml +++ /dev/null @@ -1,85 +0,0 @@ -parameters: -- name: arch - type: string - -- name: device - type: string - -- name: machine_pool - type: string - -- name: extra_job_id - type: string - default: '' - -- name: python_wheel_suffix - type: string - default: '' - - -# TODO: Ideally it should fetch information from the build that triggers it -- name: cmake_build_type - type: string - default: 'Release' - values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel - -- name: timeout - type: number - default: 120 - -jobs: -- job: Linux_Test_${{ parameters.device }}${{ parameters.extra_job_id }}_${{ parameters.arch }} - timeoutInMinutes: ${{ parameters.timeout }} - variables: - skipComponentGovernanceDetection: true - workspace: - clean: all - pool: ${{ parameters.machine_pool }} - steps: - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'drop-linux-${{ lower(parameters.device) }}-${{ parameters.arch }}' - targetPath: '$(Build.BinariesDirectory)/${{parameters.cmake_build_type}}' - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific - - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'onnxruntime${{ parameters.python_wheel_suffix }}' - targetPath: '$(Build.BinariesDirectory)/whl' - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific - - - - task: Bash@3 - displayName: 'Bash Script' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/run_python_tests.sh - arguments: -d ${{ parameters.device }} -c ${{parameters.cmake_build_type}} - - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 568ab6c8a8..de4896c108 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -246,24 +246,6 @@ stages: workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' displayName: 'Run Python Tests' - #Skip it for 32 bits x86 build. Currently the scan tool has a bug: it doesn't allow me use 64 bits link.exe - #in 32 bits Win32 build. I tried all the settings but they all don't work. - - task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - condition: and (succeeded(), and(eq(variables['buildArch'], 'x64'), eq(variables['PythonVersion'], '3.8'))) - inputs: - msBuildArchitecture: amd64 - setupCommandlines: 'python $(Build.SourcesDirectory)\tools\ci_build\build.py --config Debug --build_dir $(Build.SourcesDirectory)\b --skip_submodule_sync --cmake_generator "Visual Studio 17 2022" --enable_pybind --enable_onnx_tests --parallel $(TelemetryOption) --update --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON onnxruntime_ENABLE_LTO=OFF' - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.SourcesDirectory)\b\Debug\onnxruntime.sln" /p:RunCodeAnalysis=true /p:platform="$(MsbuildPlatform)" /p:configuration=Debug /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.SourcesDirectory)\b#$(Build.SourcesDirectory)\cmake#C:\program files#C:\program files (x86)#C:\program files' - rulesetName: Custom - customRuleset: $(Build.SourcesDirectory)\cmake\Sdl.ruleset - - - task: SdtReport@2 - displayName: 'Create Security Analysis Report' - inputs: - SDLNativeRules: true - - task: TSAUpload@2 displayName: 'TSA upload' condition: and(and (succeeded(), and(eq(variables['buildArch'], 'x64'), eq(variables['PythonVersion'], '3.8'))), eq(variables['Build.SourceBranch'], 'refs/heads/main')) @@ -502,7 +484,7 @@ stages: - template: py-linux.yml parameters: arch: 'aarch64' - machine_pool: 'aiinfra-linux-ARM64-CPU-2019' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' base_image: 'arm64v8/centos:7' devtoolset_rootpath: /opt/rh/devtoolset-10/root ld_library_path_arg: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/usr/local/lib64 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml index ef938a6345..919749cac1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml @@ -22,65 +22,6 @@ parameters: default: '' jobs: -- ${{ if eq(parameters.PYTHON_VERSION, '3.8') }}: - - job: Win_py_${{ parameters.EP_NAME }}_Wheels_StaticAnalysis - timeoutInMinutes: 240 - workspace: - clean: all - pool: onnxruntime-Win-CPU-2022 - steps: - - checkout: self - clean: true - submodules: none - - task: UsePythonVersion@0 - inputs: - versionSpec: 3.8 - addToPath: true - architecture: 'x64' - - task: onebranch.pipeline.tsaoptions@1 - displayName: 'OneBranch TSAOptions' - inputs: - tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' - appendSourceBranchName: false - - - template: download-deps.yml - - - template: jobs/set-winenv.yml - parameters: - EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} - DownloadCUDA: true - - - task: PythonScript@0 - displayName: 'Update deps.txt' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/replace_urls_in_deps.py - arguments: --new_dir $(Build.BinariesDirectory)/deps - workingDirectory: $(Build.BinariesDirectory) - - - task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - inputs: - msBuildArchitecture: amd64 - setupCommandlines: 'python $(Build.SourcesDirectory)\tools\ci_build\build.py --config Debug --build_dir $(Build.SourcesDirectory)\b --skip_submodule_sync --cmake_generator "Visual Studio 17 2022" --enable_pybind ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} --update --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON onnxruntime_ENABLE_LTO=OFF' - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.SourcesDirectory)\b\Debug\onnxruntime.sln" /p:RunCodeAnalysis=true /p:platform=x64 /p:configuration=Debug /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.SourcesDirectory)\b#$(Build.SourcesDirectory)\cmake#C:\program files#C:\program files (x86)#C:\program files' - rulesetName: Custom - customRuleset: $(Build.SourcesDirectory)\cmake\Sdl.ruleset - publishXML: true - - - task: SdtReport@2 - displayName: 'Create Security Analysis Report' - inputs: - SDLNativeRules: true - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - - - job: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} timeoutInMinutes: 240 workspace: diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 8c54e71448..e63939ae01 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -80,7 +80,7 @@ stages: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - script: brew install coreutils ninja npm yarn diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index 15254ce4d1..8bb3026520 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -116,7 +116,8 @@ stages: xcodeDeveloperDir: '/Applications/Xcode_${{ variables.xcodeVersion }}.app/Contents/Developer' signingOption: 'manual' signingIdentity: '$(APPLE_CERTIFICATE_SIGNING_IDENTITY)' - provisioningProfileName: 'iOS Team Provisioning Profile' + provisioningProfileName: 'temporary *' # temporary name, change it back to the original below later + #provisioningProfileName: 'iOS Team Provisioning Profile' args: '-derivedDataPath $(Build.BinariesDirectory)/app_center_test/ios_package_test/DerivedData' workingDirectory: '$(Build.BinariesDirectory)/app_center_test/ios_package_test/' displayName: 'Build App Center iPhone arm64 tests' diff --git a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml index 4494fd36b3..96e6ff89cd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml @@ -29,7 +29,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: 'Release_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index e0a85cc197..55535173bc 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -74,7 +74,7 @@ stages: displayName: 'Checkout submodule onnx' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: linux-web-init-and-check.yml - task: Bash@3 displayName: 'Extract commit SHA and save to __commit.txt' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index f6da7bb857..8d28b4ce58 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -101,7 +101,7 @@ stages: - task: NodeTool@0 condition: and(succeeded(), eq('${{ parameters.buildNodejs}}', true)) inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: jobs/set-winenv.yml parameters: @@ -263,25 +263,6 @@ stages: AnalyzeTargetGlob: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\**\*.dll' continueOnError: true - - task: DeleteFiles@1 - displayName: 'Delete files from $(Build.BinariesDirectory)\RelWithDebInfo' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - Contents: | - **/*.obj - **/*.pdb - **/*.dll - - #Manually set msBuildCommandline so that we can also set CAExcludePath - - task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - condition: and (succeeded(), eq(variables['msbuildPlatform'], 'x64')) - inputs: - msBuildArchitecture: amd64 - setupCommandlines: 'python $(Build.SourcesDirectory)\tools\ci_build\build.py --config Debug --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON' - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.BinariesDirectory)\Debug\onnxruntime.sln" /p:platform="$(MsbuildPlatform)" /p:configuration=Debug /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.BinariesDirectory)#$(Build.SourcesDirectory)\cmake#C:\program files (x86)' - - task: PostAnalysis@2 inputs: GdnBreakAllTools: false diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index 9d36e2dbe4..406683af80 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -74,7 +74,7 @@ jobs: architecture: $(buildArch) - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: download-deps.yml - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 713396dd64..90fc30141a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -72,7 +72,7 @@ jobs: displayName: 'Testing: force EOL to lf on windows for /js/**' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: '${{ parameters.BuildConfig }}_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml index 7235673895..f7876f1502 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml @@ -33,7 +33,7 @@ jobs: displayName: 'Checkout submodule onnx' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: 'Release_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml index 0eb78412de..58c0b8d353 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml @@ -27,7 +27,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: BatchScript@1 displayName: 'setup env' diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index 7f71f41484..b7e3ce7940 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -44,7 +44,6 @@ stages: isX86: false job_name_suffix: x64_debug RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -66,7 +65,6 @@ stages: isX86: false job_name_suffix: x64_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -86,7 +84,6 @@ stages: isX86: false job_name_suffix: x64_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: DNNL GenerateDocumentation: false @@ -108,7 +105,6 @@ stages: isX86: false job_name_suffix: x64_release RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: XNNPACK GenerateDocumentation: false @@ -129,7 +125,6 @@ stages: job_name_suffix: x64_release_winml RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} # WinML has many warnings - RunStaticCodeAnalysis: false EnablePython: false isTraining: false ORT_EP_NAME: CPU @@ -150,7 +145,6 @@ stages: isX86: true job_name_suffix: x86_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -170,7 +164,6 @@ stages: isX86: false job_name_suffix: training_x64_debug RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: true ORT_EP_NAME: CPU GenerateDocumentation: false @@ -190,7 +183,6 @@ stages: isX86: false job_name_suffix: training_x64_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: true isTraining: true ORT_EP_NAME: CPU GenerateDocumentation: false @@ -210,7 +202,6 @@ stages: isX86: false job_name_suffix: ort_training_apis_x64_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false EnablePython: false isTraining: true ORT_EP_NAME: CPU @@ -231,7 +222,6 @@ stages: isX86: false job_name_suffix: x64_release_azure RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false EnablePython: false isTraining: false ORT_EP_NAME: CPU diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index 7ab55a5d80..806ed797f8 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -49,7 +49,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false ORT_EP_NAME: CUDA WITH_CACHE: true MachinePool: onnxruntime-Win2022-GPU-A10 @@ -67,7 +66,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false ORT_EP_NAME: CUDA WITH_CACHE: true # Some unit tests crash on A10 GPUs. So this job still needs to use A10. @@ -87,7 +85,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false ORT_EP_NAME: DML WITH_CACHE: true MachinePool: onnxruntime-Win2022-GPU-dml-A10 @@ -106,7 +103,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: false - RunStaticCodeAnalysis: false GenerateDocumentation: true ORT_EP_NAME: CUDA # It doesn't really matter which EP is selected here since this stage is for documentation. WITH_CACHE: true diff --git a/tools/ci_build/github/linux/build_cuda_c_api_package.sh b/tools/ci_build/github/linux/build_cuda_c_api_package.sh index 271f010a9d..5ce1c2d2e8 100755 --- a/tools/ci_build/github/linux/build_cuda_c_api_package.sh +++ b/tools/ci_build/github/linux/build_cuda_c_api_package.sh @@ -1,10 +1,13 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + export CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" export CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" docker run --gpus all -e CFLAGS -e CXXFLAGS -e NVIDIA_VISIBLE_DEVICES=all --rm --volume \ $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \ --volume /data/models:/build/models:ro --volume /data/onnx:/data/onnx:ro -e NIGHTLY_BUILD onnxruntimecuda11centosbuild \ python3 /onnxruntime_src/tools/ci_build/build.py --build_java --build_dir /build --config Release \ ---skip_submodule_sync --parallel --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION \ +--skip_submodule_sync --parallel --nvcc_threads=1 --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION \ --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr/local/cuda-$CUDA_VERSION \ ---cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80' +--cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80' diff --git a/tools/ci_build/github/linux/build_linux_arm64_python_package.sh b/tools/ci_build/github/linux/build_linux_arm64_python_package.sh index 58d7d32ac4..a1a0d428c6 100755 --- a/tools/ci_build/github/linux/build_linux_arm64_python_package.sh +++ b/tools/ci_build/github/linux/build_linux_arm64_python_package.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + set -e -x # This script invokes build.py @@ -62,7 +65,7 @@ fi if [ "$BUILD_DEVICE" == "GPU" ]; then #Enable CUDA and TRT EPs. ONNXRUNTIME_CUDA_VERSION="11.8" - BUILD_ARGS+=("--use_cuda" "--use_tensorrt" "--cuda_version=$ONNXRUNTIME_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80") + BUILD_ARGS+=("--nvcc_threads=1" "--use_cuda" "--use_tensorrt" "--cuda_version=$ONNXRUNTIME_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80") fi export CFLAGS diff --git a/tools/ci_build/github/linux/build_yocto.sh b/tools/ci_build/github/linux/build_yocto.sh index e948a105c0..fab5173353 100755 --- a/tools/ci_build/github/linux/build_yocto.sh +++ b/tools/ci_build/github/linux/build_yocto.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + set -e -o -x SCRIPT_DIR="$( dirname "${BASH_SOURCE[0]}" )" TARGET_FOLDER="/datadrive/ARM" diff --git a/tools/ci_build/github/linux/copy_strip_binary.sh b/tools/ci_build/github/linux/copy_strip_binary.sh index 5832b0ec2e..af8f0ecb9c 100755 --- a/tools/ci_build/github/linux/copy_strip_binary.sh +++ b/tools/ci_build/github/linux/copy_strip_binary.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. set -e -o -x while getopts r:a:l:c:s:t: parameter_Option @@ -44,6 +46,7 @@ fi cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_c_api.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_api.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_inline.h $BINARY_DIR/$ARTIFACT_NAME/include +cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_float16.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include diff --git a/tools/ci_build/github/linux/create_package.sh b/tools/ci_build/github/linux/create_package.sh index ed012a5abc..305d261838 100755 --- a/tools/ci_build/github/linux/create_package.sh +++ b/tools/ci_build/github/linux/create_package.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. set -e SCRIPT=`realpath $0` SCRIPT_DIR=`dirname $SCRIPT` diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cpu index 033afde6aa..561df220af 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cpu @@ -26,7 +26,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -35,7 +34,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -137,9 +135,7 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ build_scripts/requirements3.10.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11 index dc52fb51d6..8a092c437a 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_4 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_4 index 303e83eb23..68b779e6f1 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_4 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_4 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_5 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_5 index d17e4b2458..dfc9e819ad 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_5 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_5 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_8_tensorrt8_6 index bcdc24d5eb..6e27db4eb3 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_8_tensorrt8_6 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -147,7 +145,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.7.txt \ build_scripts/requirements3.8.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm index 9f7575d62e..036d261044 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm @@ -52,7 +52,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -61,7 +60,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -164,7 +162,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_8 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_8 index 5d77446007..c3c7213212 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_8 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_8 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 new file mode 100644 index 0000000000..cdf504c8e3 --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 @@ -0,0 +1,45 @@ +# -------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------- +# Dockerfile to Test ONNX Runtime on UBI8 with CUDA 11.8 and TensorRT 8.6 + +# Build base image with required system packages +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubi8 AS base + +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} + +RUN dnf install -y bash wget &&\ + dnf clean dbcache + +# Install python3 +RUN dnf install -y \ + python3.8 \ + python38-pip \ + python38-wheel &&\ + cd /usr/local/bin &&\ + ln -s /usr/bin/python3 python3.8 &&\ + ln -s /usr/bin/pip3 pip3.8; + +RUN pip3 install --upgrade pip +RUN pip3 install setuptools>=41.0.0 + +# Install TensorRT +RUN dnf install -y libnvinfer8 libnvonnxparsers8 libnvparsers8 libnvinfer-plugin8 libnvinfer-lean8 libnvinfer-vc-plugin8 libnvinfer-dispatch8 +RUN v="8.6.1.6-1+cuda11.8" &&\ + dnf downgrade -y libnvinfer8-${v} libnvinfer8-${v} libnvonnxparsers8-${v} libnvparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-lean8-${v} libnvinfer-vc-plugin8-${v} libnvinfer-dispatch8-${v} &&\ + dnf install -y dnf-plugin-versionlock &&\ + dnf versionlock libnvinfer8 libnvonnxparsers8 libnvparsers8 libnvinfer-plugin8 libnvinfer-lean8 libnvinfer-vc-plugin8 libnvinfer-dispatch8 +RUN dnf clean dbcache + + +ADD scripts /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && /tmp/scripts/install_java.sh && rm -rf /tmp/scripts + +# Build final image from base. +FROM base as final +ARG BUILD_USER=onnxruntimedev +ARG BUILD_UID=1000 +RUN adduser --uid $BUILD_UID $BUILD_USER +WORKDIR /home/$BUILD_USER +USER $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 new file mode 100644 index 0000000000..c211fa9b9e --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 @@ -0,0 +1,53 @@ +# -------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------- +# Dockerfile to run ONNXRuntime with TensorRT integration + +# Build base image with required system packages +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base + +# The local directory into which to build and install CMAKE +ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code + +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update &&\ + apt-get install -y sudo git bash unattended-upgrades wget +RUN unattended-upgrade + +# Install python3 +RUN apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + python3-dev \ + python3-wheel &&\ + cd /usr/local/bin &&\ + ln -s /usr/bin/python3 python &&\ + ln -s /usr/bin/pip3 pip; + +RUN pip install --upgrade pip +RUN pip install setuptools>=41.0.0 + +# Install TensorRT +RUN v="8.6.1.6-1+cuda11.8" &&\ + apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ + apt-get update &&\ + sudo apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} libnvinfer-lean8=${v} libnvinfer-vc-plugin8=${v} libnvinfer-dispatch8=${v}\ + libnvinfer-headers-dev=${v} libnvinfer-headers-plugin-dev=${v} libnvinfer-dev=${v} libnvonnxparsers-dev=${v} libnvparsers-dev=${v} libnvinfer-plugin-dev=${v} libnvinfer-lean-dev=${v} libnvinfer-vc-plugin-dev=${v} libnvinfer-dispatch-dev=${v}\ + python3-libnvinfer=${v} libnvinfer-samples=${v} tensorrt-dev=${v} tensorrt-libs=${v} + +# Install Valgrind +RUN apt-get install -y valgrind + +ADD scripts /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && rm -rf /tmp/scripts + +# Build final image from base. +FROM base as final +ARG BUILD_USER=onnxruntimedev +ARG BUILD_UID=1000 +RUN adduser --uid $BUILD_UID $BUILD_USER +WORKDIR /home/$BUILD_USER +USER $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2014_cpu b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2014_cpu index 8869a78902..691e45e743 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2014_cpu +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2014_cpu @@ -26,7 +26,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -35,7 +34,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -132,7 +130,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt index c0c6505ca0..8a9c4dac1d 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt @@ -6,5 +6,5 @@ setuptools>=41.4.0 wheel git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx protobuf==3.20.2 -sympy==1.10.1 +sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/manylinux.patch b/tools/ci_build/github/linux/docker/manylinux.patch index 7750118d01..1a92b4c094 100644 --- a/tools/ci_build/github/linux/docker/manylinux.patch +++ b/tools/ci_build/github/linux/docker/manylinux.patch @@ -50,6 +50,17 @@ index 961e34d..55ae11b 100755 make install > /dev/null } +diff --git a/finalize.sh b/finalize.sh +index 621eab9..4cbcf90 100755 +--- a/finalize.sh ++++ b/finalize.sh +@@ -86,6 +86,3 @@ clean_pyc /opt/_internal + rm -rf /root/.cache + + hardlink -cv /opt/_internal +- +-# update system packages +-LC_ALL=C ${MY_DIR}/update-system-packages.sh diff --git a/install-entrypoint.sh b/install-entrypoint.sh index 9ef1e99..ec52833 100755 --- a/install-entrypoint.sh @@ -65,7 +76,7 @@ index 9ef1e99..ec52833 100755 +fi \ No newline at end of file diff --git a/install-runtime-packages.sh b/install-runtime-packages.sh -index 137d2e2..21b60a7 100755 +index 137d2e2..7a17e16 100755 --- a/install-runtime-packages.sh +++ b/install-runtime-packages.sh @@ -73,9 +73,11 @@ if [ "${AUDITWHEEL_POLICY}" == "manylinux2014" ]; then @@ -83,3 +94,15 @@ index 137d2e2..21b60a7 100755 elif [ "${AUDITWHEEL_ARCH}" == "aarch64" ] || [ "${AUDITWHEEL_ARCH}" == "ppc64le" ] || [ "${AUDITWHEEL_ARCH}" == "s390x" ]; then # Software collection (for devtoolset-10) yum -y install centos-release-scl-rh +@@ -121,11 +123,6 @@ else + exit 1 + fi + +-# update system packages, we already updated them but +-# the following script takes care of cleaning-up some things +-# and since it's also needed in the finalize step, everything's +-# centralized in this script to avoid code duplication +-LC_ALL=C ${MY_DIR}/update-system-packages.sh + + if [ "${BASE_POLICY}" == "manylinux" ]; then + # we'll be removing libcrypt.so.1 later on diff --git a/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh b/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh new file mode 100755 index 0000000000..d89a5e84c1 --- /dev/null +++ b/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +set -e -x + +if [ -f /etc/redhat-release ]; then + rpm -Uvh https://packages.microsoft.com/config/centos/7/packages-microsoft-prod.rpm + yum install -y dotnet-sdk-6.0 +elif [ -f /etc/os-release ]; then + # Get Ubuntu version + declare repo_version=$(if command -v lsb_release &> /dev/null; then lsb_release -r -s; else grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"'; fi) + # Download Microsoft signing key and repository + wget https://packages.microsoft.com/config/ubuntu/$repo_version/packages-microsoft-prod.deb -O packages-microsoft-prod.deb + # Install Microsoft signing key and repository + dpkg -i packages-microsoft-prod.deb + # Clean up + rm packages-microsoft-prod.deb + # Update packages + apt-get update && apt-get install -y dotnet-sdk-6.0 +else + echo "Unsupported OS" + exit 1 +fi diff --git a/tools/ci_build/github/linux/docker/scripts/install_java.sh b/tools/ci_build/github/linux/docker/scripts/install_java.sh new file mode 100755 index 0000000000..d11e29f693 --- /dev/null +++ b/tools/ci_build/github/linux/docker/scripts/install_java.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e -x + +if [ -f /etc/redhat-release ]; then + dnf install -y java-11-openjdk-devel \ + && dnf clean dbcache +elif [ -f /etc/os-release ]; then + apt-get update && apt-get install -y openjdk-11-jdk +else + echo "Unsupported OS" + exit 1 +fi diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh index c34abbd2ba..e569e58d54 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh @@ -3,7 +3,7 @@ set -e -x # Development tools and libraries if [ -f /etc/redhat-release ]; then - yum update && yum -y install graphviz + yum -y install graphviz os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) elif [ -f /etc/os-release ]; then apt-get update && apt-get install -y graphviz @@ -13,6 +13,9 @@ else exit 1 fi +# Install dotnet +source $(cd "$(dirname "${BASH_SOURCE[0]}")/.." &> /dev/null && pwd)/install_dotnet.sh + if [ ! -d "/opt/conda/bin" ]; then PYTHON_EXES=("/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11") else diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index c8ff7a804e..6b8003c01c 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -6,6 +6,6 @@ setuptools>=41.4.0 wheel git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx protobuf==3.20.2 -sympy==1.10.1 +sympy==1.12 flatbuffers neural-compressor>=2.2.1 diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index 2248652c98..9dbe856753 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -7,7 +7,7 @@ setuptools>=41.4.0 wheel>=0.35.1 git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx argparse -sympy==1.10.1 +sympy==1.12 flatbuffers protobuf==3.20.2 packaging diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt index 202d43befc..891291b6fa 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt @@ -2,7 +2,7 @@ pandas scikit-learn numpy==1.21.6 ; python_version < '3.11' numpy==1.24.2 ; python_version >= '3.11' -transformers==v4.4.2 +transformers==v4.16.1 rsa==4.9 tensorboard>=2.2.0,<2.5.0 h5py diff --git a/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh b/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh index 9492b7bcf5..a442a93203 100755 --- a/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh +++ b/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. set -e -o -x while getopts a: parameter_Option diff --git a/tools/ci_build/github/linux/java_copy_strip_binary.sh b/tools/ci_build/github/linux/java_copy_strip_binary.sh index 329c1b0ab9..8004e37a73 100755 --- a/tools/ci_build/github/linux/java_copy_strip_binary.sh +++ b/tools/ci_build/github/linux/java_copy_strip_binary.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. set -e -o -x while getopts r:a:l:n:c:h:v: parameter_Option diff --git a/tools/ci_build/github/linux/run_python_dockerbuild.sh b/tools/ci_build/github/linux/run_python_dockerbuild.sh index 18ac648282..2ab01eacad 100755 --- a/tools/ci_build/github/linux/run_python_dockerbuild.sh +++ b/tools/ci_build/github/linux/run_python_dockerbuild.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + set -e -x BUILD_CONFIG="Release" diff --git a/tools/ci_build/github/linux/run_python_dockertest.sh b/tools/ci_build/github/linux/run_python_dockertest.sh new file mode 100755 index 0000000000..7b080a9047 --- /dev/null +++ b/tools/ci_build/github/linux/run_python_dockertest.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +set -e -x +BUILD_CONFIG="Release" + +while getopts "i:d:x:c:" parameter_Option +do case "${parameter_Option}" +in +i) DOCKER_IMAGE=${OPTARG};; +d) DEVICE=${OPTARG};; +c) BUILD_CONFIG=${OPTARG};; +esac +done + +if [ $DEVICE = "GPU" ]; then + ADDITIONAL_DOCKER_PARAMETER="--gpus all" +fi + +mkdir -p $HOME/.onnx +docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src \ + --volume $BUILD_BINARIESDIRECTORY:/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -w /onnxruntime_src \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + $ADDITIONAL_DOCKER_PARAMETER \ + $DOCKER_IMAGE tools/ci_build/github/linux/run_python_tests.sh -d $DEVICE -c $BUILD_CONFIG diff --git a/tools/ci_build/github/linux/run_python_tests.sh b/tools/ci_build/github/linux/run_python_tests.sh index 90362a3315..cecb790b19 100755 --- a/tools/ci_build/github/linux/run_python_tests.sh +++ b/tools/ci_build/github/linux/run_python_tests.sh @@ -15,7 +15,8 @@ c) BUILD_CONFIG=${OPTARG};; esac done -cd $BUILD_BINARIESDIRECTORY +export PATH=/opt/python/cp38-cp38/bin:$PATH +cd /build files=(whl/*.whl) FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) @@ -23,7 +24,7 @@ PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') echo "Package name:$PYTHON_PACKAGE_NAME" -BUILD_ARGS="--build_dir $BUILD_BINARIESDIRECTORY --config $BUILD_CONFIG --test --skip_submodule_sync --parallel --enable_lto --build_wheel " +BUILD_ARGS="--build_dir /build --config $BUILD_CONFIG --test --skip_submodule_sync --parallel --enable_lto --build_wheel " ARCH=$(uname -m) @@ -34,20 +35,15 @@ fi if [ $BUILD_DEVICE == "GPU" ]; then BUILD_ARGS="$BUILD_ARGS --use_cuda --use_tensorrt --cuda_version=11.8 --tensorrt_home=/usr --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8" fi -# We assume the machine doesn't have gcc and python development header files, so we don't build onnxruntime from source -sudo rm -rf /build /onnxruntime_src -sudo ln -s $BUILD_SOURCESDIRECTORY /onnxruntime_src -python3 -m pip uninstall -y $PYTHON_PACKAGE_NAME ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml onnx -qq +python3 -m pip install --upgrade pip # Install the packages that are needed for installing the onnxruntime python package -python3 -m pip install -r $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/requirements.txt +python3 -m pip install -r /build/$BUILD_CONFIG/requirements.txt # Install the packages that are needed for running test scripts -# Install the latest ONNX release which may contain not fixed bugs. However, it is what most people use. -python3 -m pip install onnx pytest +python3 -m pip install pytest # The "--no-index" flag is crucial. The local whl folder is just an additional source. Pypi's doc says "there is no # ordering in the locations that are searched" if we don't disable the default one with "--no-index" -python3 -m pip install --no-index --find-links $BUILD_BINARIESDIRECTORY/whl $PYTHON_PACKAGE_NAME -ln -s /data/models $BUILD_BINARIESDIRECTORY -cd $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG +python3 -m pip install --no-index --find-links /build/whl $PYTHON_PACKAGE_NAME +cd /build/$BUILD_CONFIG # Restore file permissions xargs -a perms.txt chmod a+x -python3 $BUILD_SOURCESDIRECTORY/tools/ci_build/build.py $BUILD_ARGS --ctest_path '' +python3 /onnxruntime_src/tools/ci_build/build.py $BUILD_ARGS --ctest_path '' diff --git a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh index 56f5ff9f9e..ee4339c24d 100755 --- a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh +++ b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. pip3 install --user --upgrade pip diff --git a/tools/ci_build/github/linux/upload_ortsrv_binaries.sh b/tools/ci_build/github/linux/upload_ortsrv_binaries.sh index 9d6f9406e4..dcbea27847 100755 --- a/tools/ci_build/github/linux/upload_ortsrv_binaries.sh +++ b/tools/ci_build/github/linux/upload_ortsrv_binaries.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. set -e -o -x while getopts a:r:i:c:p:b: parameter_Option diff --git a/tools/ci_build/github/linux/yocto_build_toolchain.sh b/tools/ci_build/github/linux/yocto_build_toolchain.sh index 26d4f58348..a4e7fe7a38 100755 --- a/tools/ci_build/github/linux/yocto_build_toolchain.sh +++ b/tools/ci_build/github/linux/yocto_build_toolchain.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. set -e YOCTO_VERSION="4.19" diff --git a/tools/ci_build/requirements.txt b/tools/ci_build/requirements.txt index 620da1afa1..96659d70af 100644 --- a/tools/ci_build/requirements.txt +++ b/tools/ci_build/requirements.txt @@ -1,7 +1,8 @@ -# packages used by transformers tool test +# packages used by transformers python unittest (only enabled in Linux CPU CI Pipeline) packaging protobuf==3.20.2 numpy==1.24.0 coloredlogs==15.0 transformers==4.30.0 psutil +einops \ No newline at end of file diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 9dc36633a5..3aba1d0577 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -67,7 +67,7 @@ def generate_file_list_for_ep(nuget_artifacts_dir, ep, files_list, include_pdbs, is_versioned_dylib = re.match(r".*[\.\d+]+\.dylib$", child_file.name) if child_file.is_file() and child_file.suffix == ".dylib" and not is_versioned_dylib: files_list.append( - '' % cpu_arch + '' % cpu_arch ) for cpu_arch in ["x64", "aarch64"]: if child.name == get_package_name("linux", cpu_arch, ep, is_training_package): diff --git a/winml/test/model/model_tests.cpp b/winml/test/model/model_tests.cpp index 43b904ce77..b2dd331cce 100644 --- a/winml/test/model/model_tests.cpp +++ b/winml/test/model/model_tests.cpp @@ -379,13 +379,6 @@ std::string GetFullNameOfTest(ITestCase* testCase, winml::LearningModelDeviceKin name += tokenizedModelPath[tokenizedModelPath.size() - 2] += "_"; // model name name += tokenizedModelPath[tokenizedModelPath.size() - 3]; // opset version - // To introduce models from model zoo, the model path is structured like this "///?.onnx" - std::string source = tokenizedModelPath[tokenizedModelPath.size() - 4]; - // `models` means the root of models, to be ompatible with the old structure, that is, the source name is empty. - if (source != "models") { - name += "_" + source; - } - std::replace_if( name.begin(), name.end(), [](char c) { return !google::protobuf::ascii_isalnum(c); }, '_' ); @@ -404,6 +397,13 @@ std::string GetFullNameOfTest(ITestCase* testCase, winml::LearningModelDeviceKin ModifyNameIfDisabledTest(/*inout*/ name, deviceKind); } + // To introduce models from model zoo, the model path is structured like this "///?.onnx" + std::string source = tokenizedModelPath[tokenizedModelPath.size() - 4]; + // `models` means the root of models, to be ompatible with the old structure, that is, the source name is empty. + if (source != "models") { + name += "_" + source; + } + return name; } diff --git a/winml/test/model/skip_model_tests.h b/winml/test/model/skip_model_tests.h index 174f57143e..9d66320343 100644 --- a/winml/test/model/skip_model_tests.h +++ b/winml/test/model/skip_model_tests.h @@ -146,6 +146,8 @@ std::unordered_map disabledTests({ "Bug 31005780: Result of fp16_test_tiny_yolov2_opset7 and fp16_coreml_FNS_Candy_opset7 models on DirectML aren't as accurate as on CPU https://microsoft.visualstudio.com/OS/_workitems/edit/31005780"}, { "mlperf_ssd_mobilenet_300_opset10_GPU", "Bug 31005624: mlperf_ssd_mobilenet_300 opset 10 model fails to evaluate in DirectML https://microsoft.visualstudio.com/OS/_workitems/edit/31005624" }, + { "mlperf_ssd_resnet34_1200_opset10_GPU", + "Bug 31005624: mlperf_ssd_resnet34_1200_opset10_GPU opset 10 model fails to evaluate in DirectML https://microsoft.visualstudio.com/OS/_workitems/edit/31005624" }, }); /* @@ -161,10 +163,8 @@ std::unordered_map> disabledGpu test name -> absolute difference sampleTolerance */ std::unordered_map sampleTolerancePerTests({ - {"fp16_inception_v1_opset7_GPU",0.005 }, - {"fp16_inception_v1_opset8_GPU", 0.005}, - { "candy_opset9_GPU", - 0.00150000 }, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/ - { "fp16_tiny_yolov2_opset8_GPU", - 0.109000 }, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/ + {"fp16_inception_v1_opset7_GPU", 0.005}, + {"fp16_inception_v1_opset8_GPU", 0.005}, + { "candy_opset9_GPU", 0.00150000}, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/ + { "fp16_tiny_yolov2_opset8_GPU", 0.109000}, // Intel(R) UHD Graphics 630 (29.20.100.9020) AP machine has inaccurate GPU results for FNS Candy opset 9 https://microsoft.visualstudio.com/OS/_workitems/edit/30696168/ });