diff --git a/.gdn/.gdntsa b/.gdn/.gdntsa index 2992cab431..e49848d116 100644 --- a/.gdn/.gdntsa +++ b/.gdn/.gdntsa @@ -1,3 +1,3 @@ { - "codebaseName": "onnxruntime_master" + "codebaseName": "onnxruntime_main" } \ No newline at end of file diff --git a/.github/workflows/generate-skip-doc-change.py b/.github.upstream/generate-skip-doc-change.py similarity index 100% rename from .github/workflows/generate-skip-doc-change.py rename to .github.upstream/generate-skip-doc-change.py diff --git a/.github/workflows/skip-doc-change.yml.j2 b/.github.upstream/skip-doc-change.yml.j2 similarity index 100% rename from .github/workflows/skip-doc-change.yml.j2 rename to .github.upstream/skip-doc-change.yml.j2 diff --git a/.github/workflows/cffconvert.yml b/.github.upstream/workflows/cffconvert.yml similarity index 100% rename from .github/workflows/cffconvert.yml rename to .github.upstream/workflows/cffconvert.yml diff --git a/.github/workflows/codeql.yml b/.github.upstream/workflows/codeql.yml similarity index 100% rename from .github/workflows/codeql.yml rename to .github.upstream/workflows/codeql.yml diff --git a/.github/workflows/generated_fake_win_gpu_ci.yml b/.github.upstream/workflows/generated_fake_win_gpu_ci.yml similarity index 100% rename from .github/workflows/generated_fake_win_gpu_ci.yml rename to .github.upstream/workflows/generated_fake_win_gpu_ci.yml diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github.upstream/workflows/gradle-wrapper-validation.yml similarity index 100% rename from .github/workflows/gradle-wrapper-validation.yml rename to .github.upstream/workflows/gradle-wrapper-validation.yml diff --git a/.github/workflows/labeler.yml b/.github.upstream/workflows/labeler.yml similarity index 100% rename from .github/workflows/labeler.yml rename to .github.upstream/workflows/labeler.yml diff --git a/.github/workflows/lint.yml b/.github.upstream/workflows/lint.yml similarity index 99% rename from .github/workflows/lint.yml rename to 
.github.upstream/workflows/lint.yml index 83d2a4bfd6..91f9a8ee3d 100644 --- a/.github/workflows/lint.yml +++ b/.github.upstream/workflows/lint.yml @@ -3,8 +3,8 @@ name: Lint on: push: branches: - - master - main + - rel-* pull_request: jobs: diff --git a/.github/workflows/linux.yml b/.github.upstream/workflows/linux.yml similarity index 88% rename from .github/workflows/linux.yml rename to .github.upstream/workflows/linux.yml index 92ddf900d9..4c84406fb7 100644 --- a/.github/workflows/linux.yml +++ b/.github.upstream/workflows/linux.yml @@ -2,10 +2,14 @@ name: Linux_CI on: push: branches: - - master - main + - rel-* pull_request: +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + jobs: Onnxruntime-TVM: runs-on: ubuntu-latest diff --git a/.github/workflows/publish-c-apidocs.yml b/.github.upstream/workflows/publish-c-apidocs.yml similarity index 100% rename from .github/workflows/publish-c-apidocs.yml rename to .github.upstream/workflows/publish-c-apidocs.yml diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github.upstream/workflows/publish-csharp-apidocs.yml similarity index 98% rename from .github/workflows/publish-csharp-apidocs.yml rename to .github.upstream/workflows/publish-csharp-apidocs.yml index f038b66dec..eb327fc645 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github.upstream/workflows/publish-csharp-apidocs.yml @@ -28,7 +28,7 @@ jobs: - name: Setup .NET uses: actions/setup-dotnet@v2 with: - dotnet-version: 5.0.x + dotnet-version: 6.0.x - name: Restore dependencies run: dotnet restore csharp/ApiDocs/ApiDocs.csproj - name: Download DocFX diff --git a/.github/workflows/publish-java-apidocs.yml b/.github.upstream/workflows/publish-java-apidocs.yml similarity index 100% rename from .github/workflows/publish-java-apidocs.yml rename to .github.upstream/workflows/publish-java-apidocs.yml diff --git a/.github/workflows/publish-python-apidocs.yml 
b/.github.upstream/workflows/publish-python-apidocs.yml similarity index 80% rename from .github/workflows/publish-python-apidocs.yml rename to .github.upstream/workflows/publish-python-apidocs.yml index c0a49e348c..5f3e172fc7 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github.upstream/workflows/publish-python-apidocs.yml @@ -35,19 +35,19 @@ jobs: python3 -m pip install --upgrade pip cd docs/python python3 -m pip install -r requirements.txt - python3 -m pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ ort-nightly + python3 -m pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_cpu.html python3 -m pip list - name: Generate Python docs with Sphinx run: | cd tools/doc ./builddoc.sh /usr/bin ../.. ../../build - name: Log source commit - run: git rev-parse --short HEAD > build/docs/inference/html/source-version.txt + run: git rev-parse --short HEAD > build/docs/html/source-version.txt - name: Move Python docs into site run: | rm -rf _site/docs/api/python - mkdir -p _site/docs/api - mv build/docs/inference/html _site/docs/api/python + mkdir -p _site/docs/api/ + mv build/docs/html _site/docs/api/python - name: Upload docs artifact uses: actions/upload-artifact@v3 with: diff --git a/.github/workflows/windows.yml b/.github.upstream/workflows/windows.yml similarity index 89% rename from .github/workflows/windows.yml rename to .github.upstream/workflows/windows.yml index f2c61d3359..ffa91aa221 100644 --- a/.github/workflows/windows.yml +++ b/.github.upstream/workflows/windows.yml @@ -2,10 +2,14 @@ name: Windows_CI on: push: branches: - - master - main + - rel-* pull_request: +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + jobs: Onnxruntime-TVM: runs-on: windows-2019 diff --git a/.github/ISSUE_TEMPLATE/04-web.yml b/.github/ISSUE_TEMPLATE/04-web.yml index 15919a5983..d84226ff5b 100644 --- 
a/.github/ISSUE_TEMPLATE/04-web.yml +++ b/.github/ISSUE_TEMPLATE/04-web.yml @@ -55,8 +55,10 @@ body: attributes: label: Execution Provider options: - - WebGL - - WASM + - "'webgl' (WebGL)" + - "'wasm'/'cpu' (WebAssembly CPU)" + - "'xnnpack' (WebAssembly XNNPACK)" + - "'webgpu' (WebGPU)" - Other / Unknown multiple: yes validations: diff --git a/.github/workflows/publish-gh-pages.yml b/.github/workflows/publish-gh-pages.yml index 5ddb1e3bb0..193800442f 100644 --- a/.github/workflows/publish-gh-pages.yml +++ b/.github/workflows/publish-gh-pages.yml @@ -84,7 +84,7 @@ jobs: sudo mv apidocs/docs/api/python _site/docs/api - name: Upload site - uses: actions/upload-pages-artifact@v1 + uses: actions/upload-pages-artifact@v2 with: retention-days: 21 diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index 518ea8e69a..b966793cc0 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -21,7 +21,7 @@ permissions: jobs: build: name: Generate Objective-C API docs - runs-on: macos-12 + runs-on: macos-13 steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/sca.yml b/.github/workflows/sca.yml new file mode 100644 index 0000000000..4b19e71744 --- /dev/null +++ b/.github/workflows/sca.yml @@ -0,0 +1,133 @@ +name: Windows_SCA +on: + push: + branches: + - main + - rel-* + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +env: + AZCOPY_AUTO_LOGIN_TYPE: MSI + AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + +jobs: + Onnxruntime-SCA-training-CUDA: + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] + steps: + - uses: actions/checkout@v3 + with: + submodules: false + - uses: actions/setup-python@v3 + with: + python-version: '3.11.x' + architecture: 'x64' + + - uses: actions/setup-node@v3 + with: + node-version: 18 + + - name: Download cuda + 
run: azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v11.8" cuda_sdk + + + - name: Delete build folder + run: | + if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b } + &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x64 -install_prefix D:\b\Debug\installed -build_config Debug + + # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. + - name: Build code + env: + CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' + run: python tools\ci_build\build.py --enable_training --build_java --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_pybind --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --use_cuda --cuda_home=${{ github.workspace }}\cuda_sdk\v11.8 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 + + - name: Generate sarif + working-directory: D:\b + run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output + + - name: Upload SARIF to GitHub + uses: github/codeql-action/upload-sarif@v2 + continue-on-error: true + with: + sarif_file: ${{ github.workspace }}\output\MergeResult.sarif + category: VS_SCA + + # No python + Onnxruntime-SCA-win32-WINML-x64: + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] + steps: + - uses: actions/checkout@v3 + with: + submodules: false + - uses: actions/setup-python@v3 + with: + python-version: '3.11.x' + architecture: 'x64' + + - uses: actions/setup-node@v3 + with: + node-version: 18 + + - name: Delete build folder + run: | + if (Test-Path 
D:\b) { Remove-Item -Recurse -Force D:\b } + &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x64 -install_prefix D:\b\Debug\installed -build_config Debug + + # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. + - name: Build code + env: + CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' + run: python tools\ci_build\build.py --enable_training --build_java --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --ms_experimental --use_dml --use_winml --disable_rtti --enable_wcos --build_shared_lib + + - name: Generate sarif + working-directory: D:\b + run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output + + - name: Upload SARIF to GitHub + uses: github/codeql-action/upload-sarif@v2 + continue-on-error: true + with: + sarif_file: ${{ github.workspace }}\output\MergeResult.sarif + category: VS_SCA_WIN32_WINML_X64 + + # No java, No python + Onnxruntime-SCA-win32-WINML-x86: + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] + steps: + - uses: actions/checkout@v3 + with: + submodules: false + - uses: actions/setup-python@v3 + with: + python-version: '3.11.x' + architecture: 'x86' + + - uses: actions/setup-node@v3 + with: + node-version: 18 + + - name: Delete build folder + run: | + if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b } + &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x86 -install_prefix D:\b\Debug\installed -build_config Debug + + # The build machine 
doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. + - name: Build code + env: + CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' + run: python tools\ci_build\build.py --enable_training --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --ms_experimental --use_dml --use_winml --disable_rtti --enable_wcos --build_shared_lib + + - name: Generate sarif + working-directory: D:\b + run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output + + - name: Upload SARIF to GitHub + uses: github/codeql-action/upload-sarif@v2 + continue-on-error: true + with: + sarif_file: ${{ github.workspace }}\output\MergeResult.sarif + category: VS_SCA_WIN32_WINML_X86 diff --git a/.github/workflows/wheel.yaml b/.github/workflows/wheel.yaml new file mode 100644 index 0000000000..b4f7b984ef --- /dev/null +++ b/.github/workflows/wheel.yaml @@ -0,0 +1,47 @@ +name: CI && Release & Upload Wheel + +on: + workflow_dispatch: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + build_and_upload_wheel: + runs-on: linux-mk + container: + image: ghcr.io/quadric-io/tvm:devel + options: "--mount type=bind,source=${{ github.workspace }},target=/workspace" + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Build ONNX Runtime wheel + working-directory: /workspace + run: | + python3 -m pip install cmake --upgrade + ./build.sh --build_wheel --config Release --parallel ${{ github.event_name == 'pull_request' && ' ' || '--skip_tests'}} 
--skip_submodule_sync + wheel_path=$(find . -name '*.whl' | xargs readlink -f) + echo "wheel_path=$wheel_path" >> $GITHUB_ENV + - name: Count releases + id: count_releases + if: github.event_name != 'pull_request' + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + apt-get update && apt-get install curl jq -y + count=$(curl --request GET \ + --url https://api.github.com/repos/${{ github.repository }}/releases \ + --header "Authorization: Bearer $GITHUB_TOKEN" | jq length) + echo "count=$count" >> $GITHUB_ENV + - name: Create Release + if: github.event_name != 'pull_request' + uses: softprops/action-gh-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: v${{ env.count }} + name: Release v${{ env.count }} + files: ${{ env.wheel_path }} diff --git a/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml b/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml index fa2b475fe9..b9de1b79e1 100644 --- a/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml +++ b/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml @@ -351,6 +351,31 @@ extends: - script: | dir $(Build.SourcesDirectory)\unzipped\runtimes\win-x64\_native + - task: EsrpCodeSigning@2 + displayName: "Sign Nuget package" + inputs: + ConnectedServiceName: 'OnnxRuntime CodeSign 20190817' + FolderPath: $(Build.ArtifactStagingDirectory) + Pattern: '*.nupkg' + signConfigType: inlineSignParams + inlineOperation: | + [ + { + "keyCode": "CP-401405", + "operationSetCode": "NuGetSign", + "parameters": [ ], + "toolName": "sign", + "toolVersion": "1.0" + }, + { + "keyCode": "CP-401405", + "operationSetCode": "NuGetVerify", + "parameters": [ ], + "toolName": "sign", + "toolVersion": "1.0" + } + ] + - job: NuGet_Publishing pool: type: windows diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config index 2ec9b577c1..8eef0b5bac 100644 --- a/.pipelines/nuget_config/x64/packages.config +++ 
b/.pipelines/nuget_config/x64/packages.config @@ -1,6 +1,6 @@  - + diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config index 6f9e3ef496..81f97948f1 100644 --- a/.pipelines/nuget_config/x86/packages.config +++ b/.pipelines/nuget_config/x86/packages.config @@ -1,6 +1,6 @@  - + diff --git a/.pipelines/windowsai-steps.yml b/.pipelines/windowsai-steps.yml index 0b736da427..a29e2fe6c8 100644 --- a/.pipelines/windowsai-steps.yml +++ b/.pipelines/windowsai-steps.yml @@ -80,11 +80,11 @@ jobs: # must call vsdevcmd first to add cmake to PATH - script: | - curl -O -L https://github.com/Kitware/CMake/releases/download/v3.24.3/cmake-3.24.3-windows-x86_64.zip - 7z x cmake-3.24.3-windows-x86_64.zip + curl -O -L https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-windows-x86_64.zip + 7z x cmake-3.26.3-windows-x86_64.zip set PYTHONHOME=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools set PYTHONPATH=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools - $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 16 2019" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.24.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.24.3-windows-x86_64\bin\ctest.exe + $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 16 2019" --update 
--config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Generate cmake config' diff --git a/Package.swift b/Package.swift index 7f8bfe0c3c..54609b104d 100644 --- a/Package.swift +++ b/Package.swift @@ -79,24 +79,10 @@ if let pod_archive_path = ProcessInfo.processInfo.environment["ORT_IOS_POD_LOCAL package.targets.append(Target.binaryTarget(name: "onnxruntime", path: pod_archive_path)) } else { - // When creating the release version: - // - remove the fatalError - // - uncomment the package.targets.append call - // - update the major/minor/patch version info in the url - // - insert the checksum info from the onnxruntime-ios-packaging-pipeline CI's 'Print ORT iOS Pod checksum' - // stage output (or download the pod archive artifact from the CI and run `shasum -a 256 ` - // to manually calculate it). - // The checksum length and chars should look something like - // "c89cd106ff02eb3892243acd7c4f2bd8e68c2c94f2751b5e35f98722e10c042b" - // - // package.targets.append( - // Target.binaryTarget(name: "onnxruntime", - // url: "https://onnxruntimepackages.z14.web.core.windows.net/pod-archive-onnxruntime-c-.zip", - // checksum: "Insert checksum here") - // ) - - fatalError("It is not valid to use a non-release branch from https://github.com/microsoft/onnxruntime.\n" + - "Please use a release branch (e.g. 
rel-1.15.0), or build the ONNX Runtime iOS pod archive locally " + - "and set the ORT_IOS_POD_LOCAL_PATH environment variable.\n" + - "See Package.swift for more information on using a local pod archive.") + // ORT 1.15.0 release + package.targets.append( + Target.binaryTarget(name: "onnxruntime", + url: "https://onnxruntimepackages.z14.web.core.windows.net/pod-archive-onnxruntime-c-1.15.0.zip", + checksum: "9b41412329a73d7d298b1d94ab40ae9adb65cb84f132054073bc82515b4f5f82") + ) } diff --git a/README_EPU.md b/README_EPU.md new file mode 100644 index 0000000000..e42cbe176d --- /dev/null +++ b/README_EPU.md @@ -0,0 +1,27 @@ +# The Quadric Version of onnxruntime + +This repository contains the a distribution of onnxruntime with additional operator quantization capabilities. + + +## Prerequisites: +- python 3.9 +- pip + +## Clone repository and build: +``` +git clone --recursive https://github.com/quadric-io/onnxruntime onnxruntime +cd onnxruntime +# Install wheel +pip3 install wheel +# Build the python package +./build.sh --build_wheel --config Release --parallel --compile_no_warning_as_error --skip_tests --skip_submodule_sync +``` + +## Install +``` +# Find the wheel you just created +$ find . -name '*.whl' +./build/MacOS/Release/dist/onnxruntime-1.15.1-cp310-cp310-macosx_13_0_arm64.whl +# Install it +pip3 install ./build/MacOS/Release/dist/onnxruntime-1.15.1-cp310-cp310-macosx_13_0_arm64.whl +``` diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index b4d981d42d..e099840e64 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -5993,3 +5993,32 @@ https://github.com/tensorflow/tfjs WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +—— + +curl/curl + +https://github.com/curl + +COPYRIGHT AND PERMISSION NOTICE + +Copyright (C) Daniel Stenberg, , and many +contributors, see the THANKS file. + +All rights reserved. 
+ +Permission to use, copy, modify, and distribute this software for any purpose +with or without fee is hereby granted, provided that the above copyright +notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. IN +NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +OR OTHER DEALINGS IN THE SOFTWARE. + +Except as contained in this notice, the name of a copyright holder shall not +be used in advertising or otherwise to promote the sale, use or other dealings +in this Software without prior written authorization of the copyright holder. \ No newline at end of file diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 141f2e805b..ace44233b4 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.15.0 +1.15.1 diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 989756361b..4f4c1e7e74 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -112,7 +112,7 @@ "component": { "type": "git", "git": { - "commitHash": "9b7bca2a723ff94edcd007d93b5d0cf1838591dc", + "commitHash": "a0d77f18516d2da7468a96b0de3b737266f23176", "repositoryUrl": "https://github.com/onnx/onnx.git" }, "comments": "git submodule at cmake/external/onnx" @@ -288,16 +288,6 @@ "comments": "mp11" } }, - { - "component": { - "type": "git", - "git": { - "commitHash": "3b58938e025c41d2fcd89fa22028eefaa81a18ad", - "repositoryUrl": "https://github.com/onnx/onnx.git" - }, - "comments": "onnx" - } - }, { "component": { "type": "git", @@ -447,6 +437,26 @@ }, "comments": "triton" } + }, + { + "component": { + "type": "git", + "git": { + 
"commitHash": "94142d8391c9791ec71c38336436319a2d4ac7a0", + "repositoryUrl": "https://github.com/microsoft/onnxruntime-extensions.git" + }, + "comments": "extensions" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "b16d1fa8ee567b52c09a0f89940b07d8491b881d", + "repositoryUrl": "https://github.com/curl/curl.git" + }, + "comments": "curl" + } } ] } diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0b2f980ab2..2e54063598 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -226,6 +226,7 @@ option(onnxruntime_BUILD_CACHE "onnxruntime build with cache" OFF) cmake_dependent_option(MSVC_Z7_OVERRIDE "replacing /Zi and /ZI with /Z7 when using MSVC with CCache" ON "onnxruntime_BUILD_CACHE; MSVC" OFF) option(onnxruntime_USE_AZURE "Build with azure inferencing support" OFF) +option(onnxruntime_USE_LOCK_FREE_QUEUE "Build with lock-free task queue for threadpool." OFF) # ENABLE_TRAINING includes all training functionality # The following 2 entry points @@ -272,6 +273,9 @@ if (onnxruntime_USE_ROCM) file(GLOB rocm_cmake_components ${onnxruntime_ROCM_HOME}/lib/cmake/*) list(APPEND CMAKE_PREFIX_PATH ${rocm_cmake_components}) + # Force cmake to accept the configured HIP compiler. Because the configured CMAKE_PREFIX_PATH does not work during + # enable_language(HIP), we might need to move configuring of CMAKE_PREFIX_PATH to build.py (in the future). + set(CMAKE_HIP_COMPILER_FORCED ON) enable_language(HIP) # NOTE: Flags -mllvm -amdgpu-early-inline-all=true are critical for gpu kernel code performance. -mllvm passes the @@ -740,7 +744,9 @@ if (onnxruntime_USE_AZURE) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_AZURE=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES azure) endif() - +if (onnxruntime_USE_LOCK_FREE_QUEUE) + add_compile_definitions(USE_LOCK_FREE_QUEUE) +endif() if (onnxruntime_ENABLE_LAZY_TENSOR) # To support LazyTensor, ORT needs to call Python function from C/C++. 
diff --git a/cmake/deps.txt b/cmake/deps.txt index 6b7fb0c95f..b887418256 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -23,7 +23,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/5f4caba4e7a9017816e47becdd918fcc872039ba.zip;fd119887d0d17c37adf1fc227b054befa28158ad mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.79.0.zip;c8f04e378535ededbe5af52c8f969d2dedbe73d5 -onnx;https://github.com/onnx/onnx/archive/3b58938e025c41d2fcd89fa22028eefaa81a18ad.zip;e0e5dda9eea5cd5ecae3bd8be86e477016b6be02 +onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.14.0.zip;3c6e43a36f94addc15afe939860127a1d74a9488 #use the last commit of 8.6-EA branch (https://github.com/onnx/onnx-tensorrt/commit/ba6a4fb34fdeaa3613bf981610c657e7b663a699) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/ba6a4fb34fdeaa3613bf981610c657e7b663a699.zip;5a474ed86e2c4ee4085d3daeff8222866e933dc0 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa @@ -48,4 +48,5 @@ b64;https://github.com/libb64/libb64/archive/refs/tags/v2.0.0.1.zip;815b6d31d50d pthread;https://sourceforge.net/projects/pthreads4w/files/pthreads4w-code-v3.0.0.zip;3b9e417e4474c34542b76ad40529e396ac109fb4 triton;https://github.com/triton-inference-server/server/archive/refs/tags/v2.28.0.zip;4b305570aa1e889946e20e36050b6770e4108fee # above are deps introduced by triton client, might remove after 1.14 release -extensions;https://github.com/microsoft/onnxruntime-extensions/archive/81e7799c69044c745239202085eb0a98f102937b.zip;d53487035174a046628359289ad27aa0ac0380c9 +extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c 
+curl;https://github.com/curl/curl/archive/refs/tags/curl-8_0_1.zip;b16d1fa8ee567b52c09a0f89940b07d8491b881d \ No newline at end of file diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake index 161db093a1..1ee3a06b41 100644 --- a/cmake/external/dml.cmake +++ b/cmake/external/dml.cmake @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config) set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config) get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE) - set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.10.1) + set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.12.0) # Restore nuget packages, which will pull down the DirectML redist package. add_custom_command( diff --git a/cmake/external/emsdk b/cmake/external/emsdk index 0ab19024f0..b113f24842 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit 0ab19024f08c6673a713e454ef8bd95e174c807f +Subproject commit b113f24842c6e97fe3e352084db09a6e278593ae diff --git a/cmake/external/onnx b/cmake/external/onnx index 9b7bca2a72..a0d77f1851 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit 9b7bca2a723ff94edcd007d93b5d0cf1838591dc +Subproject commit a0d77f18516d2da7468a96b0de3b737266f23176 diff --git a/cmake/external/triton.cmake b/cmake/external/triton.cmake index b24768bd89..41da407adf 100644 --- a/cmake/external/triton.cmake +++ b/cmake/external/triton.cmake @@ -44,13 +44,11 @@ if (WIN32) vcpkg_install(re2) vcpkg_install(boost-interprocess) vcpkg_install(boost-stacktrace) - vcpkg_install(zlib) vcpkg_install(pthread) vcpkg_install(b64) add_dependencies(getb64 getpthread) - add_dependencies(getpthread getzlib) - add_dependencies(getzlib getboost-stacktrace) + add_dependencies(getpthread getboost-stacktrace) add_dependencies(getboost-stacktrace getboost-interprocess) add_dependencies(getboost-interprocess getre2) 
add_dependencies(getre2 getrapidjson) @@ -59,11 +57,11 @@ if (WIN32) ExternalProject_Add(triton GIT_REPOSITORY https://github.com/triton-inference-server/client.git - GIT_TAG r22.12 + GIT_TAG r23.05 PREFIX triton SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build - CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${onnxruntime_target_platform}-windows -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON + CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${onnxruntime_target_platform}-windows -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF INSTALL_COMMAND "" UPDATE_COMMAND "") @@ -77,7 +75,9 @@ else() PREFIX rapidjson SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/rapidjson-src BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/rapidjson-build - CMAKE_ARGS -DRAPIDJSON_BUILD_TESTS=OFF -DRAPIDJSON_BUILD_DOC=OFF -DRAPIDJSON_BUILD_EXAMPLES=OFF) + CMAKE_ARGS -DRAPIDJSON_BUILD_TESTS=OFF -DRAPIDJSON_BUILD_DOC=OFF -DRAPIDJSON_BUILD_EXAMPLES=OFF + INSTALL_COMMAND "" + UPDATE_COMMAND "") ExternalProject_Get_Property(rapidjson source_dir) set(RAPIDJSON_INCLUDE_DIR ${source_dir}/include) @@ -85,16 +85,14 @@ else() ExternalProject_Add(triton GIT_REPOSITORY https://github.com/triton-inference-server/client.git - GIT_TAG r22.12 + GIT_TAG r23.05 PREFIX triton SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build - CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF INSTALL_COMMAND "" UPDATE_COMMAND "") - add_dependencies(triton rapidjson) - endif() #if (WIN32) ExternalProject_Get_Property(triton SOURCE_DIR) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 02861458c9..92f59b8326 100644 --- 
a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -28,9 +28,9 @@ macro(get_mobile_api_headers _HEADERS) ) if (onnxruntime_ENABLE_TRAINING_APIS) - list(APPEND ${_HEADERS} "${REPO_ROOT/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h}") - list(APPEND ${_HEADERS} "${REPO_ROOT/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h}") - list(APPEND ${_HEADERS} "${REPO_ROOT/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline_api.h}") + list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h") + list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h") + list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h") endif() # need to add header files for enabled EPs diff --git a/cmake/onnxruntime_config.h.in b/cmake/onnxruntime_config.h.in index 9fc46018da..4f5125569c 100644 --- a/cmake/onnxruntime_config.h.in +++ b/cmake/onnxruntime_config.h.in @@ -21,5 +21,5 @@ #cmakedefine HAS_FORMAT_TRUNCATION #cmakedefine HAS_BITWISE_INSTEAD_OF_LOGICAL #cmakedefine HAS_REALLOCARRAY -#cmakedefine ORT_VERSION "@ORT_VERSION@" -#cmakedefine ORT_BUILD_INFO "@ORT_BUILD_INFO@" +#cmakedefine ORT_VERSION u8"@ORT_VERSION@" +#cmakedefine ORT_BUILD_INFO u8"@ORT_BUILD_INFO@" diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index 5c947a52b7..1254cd16f2 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -40,13 +40,13 @@ onnxruntime_add_static_library(onnxruntime_framework ${onnxruntime_framework_src if (onnxruntime_USE_AZURE) add_dependencies(onnxruntime_framework triton) - target_include_directories(onnxruntime_framework PRIVATE ${TRITON_BIN}/include) + target_include_directories(onnxruntime_framework PRIVATE ${TRITON_BIN}/include ${TRITON_THIRD_PARTY}/curl/include) 
link_directories(${TRITON_BIN}/lib ${TRITON_BIN}/lib64 ${TRITON_THIRD_PARTY}/curl/lib ${TRITON_THIRD_PARTY}/curl/lib64) if (WIN32) link_directories(${VCPKG_SRC}/installed/${onnxruntime_target_platform}-windows/lib) - target_link_libraries(onnxruntime_framework PRIVATE libcurl httpclient_static ws2_32 crypt32 Wldap32 zlib) + target_link_libraries(onnxruntime_framework PRIVATE libcurl httpclient_static ws2_32 crypt32 Wldap32) else() diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 2839713406..0daa1b8a3d 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -594,7 +594,7 @@ if (onnxruntime_USE_DNNL) add_dependencies(onnxruntime_providers_dnnl onnxruntime_providers_shared project_dnnl ${onnxruntime_EXTERNAL_DEPENDENCIES}) target_include_directories(onnxruntime_providers_dnnl PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${DNNL_INCLUDE_DIR} ${DNNL_OCL_INCLUDE_DIR}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found - target_link_libraries(onnxruntime_providers_dnnl PRIVATE dnnl ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${ABSEIL_LIBS} ${GSL_TARGET}) + target_link_libraries(onnxruntime_providers_dnnl PRIVATE dnnl ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${ABSEIL_LIBS} ${GSL_TARGET} safeint_interface) install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/dnnl DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers) set_target_properties(onnxruntime_providers_dnnl PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_providers_dnnl PROPERTIES LINKER_LANGUAGE CXX) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 56a7a4b350..9347be180d 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -848,7 +848,7 @@ if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (onnxruntime_BUILD_WEBASSEMBLY) 
set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js) - set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=1048576 -s ALLOW_MEMORY_GROWTH=1 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1") + set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1") if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s USE_PTHREADS=1 -s PROXY_TO_PTHREAD=1") endif() diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj index 5473246e8b..0288d752d8 100644 --- a/csharp/OnnxRuntime.CSharp.proj +++ b/csharp/OnnxRuntime.CSharp.proj @@ -16,6 +16,7 @@ CMake creates a target to this project nuget x64 false + false None @@ -77,6 +78,7 @@ CMake creates a target to this project $([System.DateTime]::UtcNow.ToString(yyyyMMdd)) $([System.DateTime]::UtcNow.ToString(hhmm)) @(MajorVersionNumber) + $(PackageVersion)$(ReleaseVersionSuffix) $(PackageVersion) $(PackageVersion)-dev-$(CurrentDate)-$(CurrentTime)-$(GitCommitHashShort) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj index 78083a8cc1..ad468b0c6d 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj @@ -41,6 +41,7 @@ net6.0;net6.0-android;net6.0-ios;net6.0-macos diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 
14e93dc05f..5fcddf7d0c 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -286,6 +286,7 @@ public struct OrtApi public IntPtr CastTypeInfoToOptionalTypeInfo; public IntPtr GetOptionalContainedTypeInfo; public IntPtr GetResizedStringTensorElementBuffer; + public IntPtr KernelContext_GetAllocator; } internal static class NativeMethods @@ -1654,7 +1655,6 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtModelMetadataLookupCustomMetadataMap OrtModelMetadataLookupCustomMetadataMap; - /// /// Frees ModelMetadata instance /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs index 2df16ccae0..1a03338298 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs @@ -276,7 +276,7 @@ public static OrtEnv Instance() /// /// /// if the singleton has already been created - public static OrtEnv CreateInstanceWithOptions(EnvironmentCreationOptions options) + public static OrtEnv CreateInstanceWithOptions(ref EnvironmentCreationOptions options) { // Non-thread safe, best effort hopefully helpful check. // Environment is usually created once per process, so this should be fine. diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 6bbb159a3d..30951bae3f 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -375,10 +375,10 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt /// Optional key/value pairs to specify execution provider options. 
public void AppendExecutionProvider(string providerName, Dictionary providerOptions = null) { - if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN") + if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE") { throw new NotSupportedException( - "Only QNN, SNPE and XNNPACK execution providers can be enabled by this method."); + "Only QNN, SNPE, XNNPACK and AZURE execution providers can be enabled by this method."); } if (providerOptions == null) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 5fb983ab37..83d114b535 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -401,21 +401,22 @@ public void FromBuffer(FixedBufferOnnxValue buffer) throw new ArgumentException(errorMessage); } - IntPtr numElementsTrainingOnly = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out numElementsTrainingOnly)); - - UIntPtr bufferSize = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, true)); - if ((long)bufferSize.ToUInt64() == numElementsTrainingOnly.ToInt64()) + // Here buffer size represents the number of elements in the buffer + IntPtr bufferSize = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out bufferSize)); + + // OrtGetParametersSize returns the total number of elements in the model's parameters. 
+ UIntPtr numElementsTrainingOnly = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true)); + if (bufferSize.ToInt64() == (long)numElementsTrainingOnly.ToUInt64()) { NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true)); return; } - IntPtr numElements = IntPtr.Zero; - bufferSize = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, false)); - if ((long)bufferSize.ToUInt64() != numElements.ToInt64()) + UIntPtr numElements = UIntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false)); + if (bufferSize.ToInt64() != (long)numElements.ToUInt64()) { string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). 
Actual size: " + bufferSize.ToString(); throw new ArgumentException(errorMessage); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs index 63ddf51116..8a94b199ff 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs @@ -80,7 +80,7 @@ public void TestUpdatingEnvWithCustomLogLevel() logLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL }; - ortEnvInstance = OrtEnv.CreateInstanceWithOptions(envOptions); + ortEnvInstance = OrtEnv.CreateInstanceWithOptions(ref envOptions); Assert.True(OrtEnv.IsCreated); Assert.Equal(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, ortEnvInstance.EnvLogLevel); @@ -92,7 +92,7 @@ public void TestUpdatingEnvWithCustomLogLevel() logId = "CSharpOnnxRuntimeTestLogid" }; - ortEnvInstance = OrtEnv.CreateInstanceWithOptions(envOptions); + ortEnvInstance = OrtEnv.CreateInstanceWithOptions(ref envOptions); Assert.Equal(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, ortEnvInstance.EnvLogLevel); // Change and see if this takes effect @@ -118,7 +118,7 @@ public void TestUpdatingEnvWithThreadingOptions() }; // Make sure we start anew - var env = OrtEnv.CreateInstanceWithOptions(envOptions); + var env = OrtEnv.CreateInstanceWithOptions(ref envOptions); Assert.True(OrtEnv.IsCreated); } } @@ -165,7 +165,7 @@ public void TesEnvWithCustomLogger() LoggingInvokes = 0; - var env = OrtEnv.CreateInstanceWithOptions(envOptions); + var env = OrtEnv.CreateInstanceWithOptions(ref envOptions); Assert.True(OrtEnv.IsCreated); var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); @@ -199,7 +199,7 @@ public void TestEnvWithCustomLoggerAndThredingOptions() LoggingInvokes = 0; - var env = OrtEnv.CreateInstanceWithOptions(envOptions); + var env = OrtEnv.CreateInstanceWithOptions(ref envOptions); Assert.True(OrtEnv.IsCreated); var model = 
TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 8427dfaef4..d4b2310fa1 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -955,6 +955,7 @@ Do not modify directly.* |||7+|**T** = tensor(float), tensor(float16)
**T1** = tensor(bool)| |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| @@ -1101,7 +1102,8 @@ Do not modify directly.* |||11+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)| |||10+|**T** = tensor(float), tensor(float16)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| +|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| +|||10+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| |Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)| |STFT|*in* signal:**T1**
*in* frame_step:**T2**
*in* window:**T1**
*in* frame_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -1194,6 +1196,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| +|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| diff --git a/docs/python/README.rst b/docs/python/README.rst index f5f162b1d4..e47f2ad19d 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime mtx(spin_lock_); +#else std::lock_guard lock(mutex_); +#endif unsigned back = back_.load(std::memory_order_relaxed); Elem& e = array_[(back - 1) & kMask]; ElemState s = e.state.load(std::memory_order_relaxed); @@ -469,7 +474,11 @@ class RunQueue { // with w_idx. Typically the tag will be a per-thread ID to distinguish work // submitted from different threads. PushResult PushBackWithTag(Work w, Tag tag, unsigned& w_idx) { +#ifdef USE_LOCK_FREE_QUEUE + std::lock_guard mtx(spin_lock_); +#else std::lock_guard lock(mutex_); +#endif unsigned back = back_.load(std::memory_order_relaxed); w_idx = (back - 1) & kMask; Elem& e = array_[w_idx]; @@ -490,7 +499,11 @@ class RunQueue { Work PopBack() { if (Empty()) return Work(); +#ifdef USE_LOCK_FREE_QUEUE + std::lock_guard mtx(spin_lock_); +#else std::lock_guard lock(mutex_); +#endif unsigned back; Elem* e; ElemState s; @@ -532,7 +545,11 @@ class RunQueue { bool RevokeWithTag(Tag tag, unsigned w_idx) { bool revoked = false; +#ifdef USE_LOCK_FREE_QUEUE + std::lock_guard mtx(spin_lock_); +#else std::lock_guard lock(mutex_); +#endif Elem& e = array_[w_idx]; ElemState s = e.state.load(std::memory_order_relaxed); @@ -604,7 +621,11 @@ class RunQueue { Work w; }; +#ifdef USE_LOCK_FREE_QUEUE + OrtSpinLock spin_lock_; +#else OrtMutex mutex_; +#endif // Low log(kSize) + 1 bits in front_ and back_ contain rolling index of // front/back, respectively. 
The remaining bits contain modification counters diff --git a/include/onnxruntime/core/platform/ort_spin_lock.h b/include/onnxruntime/core/platform/ort_spin_lock.h new file mode 100644 index 0000000000..db80abe1ee --- /dev/null +++ b/include/onnxruntime/core/platform/ort_spin_lock.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include "core/common/spin_pause.h" +#include + +namespace onnxruntime { +/* +OrtSpinLock implemented mutex semantic "lock-freely", +calling thread will not be put to sleep on blocked, +which reduces cpu usage on context switching. +*/ +struct OrtSpinLock { + using LockState = enum { Locked = 0, + Unlocked }; + + void lock() noexcept { + LockState state = Unlocked; + while (!state_.compare_exchange_weak(state, Locked, std::memory_order_acq_rel, std::memory_order_relaxed)) { + state = Unlocked; + concurrency::SpinPause(); // pause and retry + } + } + bool try_lock() noexcept { + LockState state = Unlocked; + return state_.compare_exchange_weak(state, Locked, std::memory_order_acq_rel, std::memory_order_relaxed); + } + void unlock() noexcept { + state_.store(Unlocked, std::memory_order_release); + } + + private: + std::atomic state_{Unlocked}; +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index 18dcfeea68..084e66fd81 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -14,21 +14,21 @@ /// User can only get the instance of OrtCUDAProviderOptionsV2 via CreateCUDAProviderOptions. /// struct OrtCUDAProviderOptionsV2 { - int device_id; // cuda device id. - int has_user_compute_stream; // indicator of user specified CUDA compute stream. - void* user_compute_stream; // user specified CUDA compute stream. 
- int do_copy_in_default_stream; // flag specifying if the default stream is to be used for copying. - OrtCudnnConvAlgoSearch cudnn_conv_algo_search; // cudnn algo search enum. - size_t gpu_mem_limit; // BFC Arena memory limit for CUDA. - // (will be overridden by contents of `default_memory_arena_cfg` is it exists) - onnxruntime::ArenaExtendStrategy arena_extend_strategy; // BFC Arena extension strategy. - // (will be overridden by contents of `default_memory_arena_cfg` is it exists) - OrtArenaCfg* default_memory_arena_cfg; // BFC Arena config flags. - int cudnn_conv_use_max_workspace; // flag specifying if maximum workspace can be used in cudnn conv algo search. - int enable_cuda_graph; // flag specifying if the CUDA graph is to be captured for the model. - int cudnn_conv1d_pad_to_nc1d; // flag specifying if pad Conv1D's input [N,C,D] to [N,C,1,D] or [N,C,D,1]. - int tunable_op_enable; // flag specifying if TunableOp is enabled. - int tunable_op_tuning_enable; // flag specifying if TunableOp is enabled for tuning, this relies on TunableOp is enabled. - int enable_skip_layer_norm_strict_mode; // flag specifying if SkipLayerNorm is in strict mode. If true, use LayerNormalization kernel. - // The strict mode has better accuracy but lower performance. + int device_id = 0; // cuda device id. + int has_user_compute_stream = 0; // indicator of user specified CUDA compute stream. + void* user_compute_stream = nullptr; // user specified CUDA compute stream. + int do_copy_in_default_stream = 1; // flag specifying if the default stream is to be used for copying. + OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearchExhaustive; // cudnn algo search enum. + size_t gpu_mem_limit = std::numeric_limits::max(); // BFC Arena memory limit for CUDA. 
+ // (will be overridden by contents of `default_memory_arena_cfg` is it exists) + onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo; // BFC Arena extension strategy. + // (will be overridden by contents of `default_memory_arena_cfg` is it exists) + OrtArenaCfg* default_memory_arena_cfg = nullptr; // BFC Arena config flags. + int cudnn_conv_use_max_workspace = 1; // flag specifying if maximum workspace can be used in cudnn conv algo search. + int enable_cuda_graph = 0; // flag specifying if the CUDA graph is to be captured for the model. + int cudnn_conv1d_pad_to_nc1d = 0; // flag specifying if pad Conv1D's input [N,C,D] to [N,C,1,D] or [N,C,D,1]. + int tunable_op_enable = 0; // flag specifying if TunableOp is enabled. + int tunable_op_tuning_enable = 0; // flag specifying if TunableOp is enabled for tuning, this relies on TunableOp is enabled. + int enable_skip_layer_norm_strict_mode = 0; // flag specifying if SkipLayerNorm is in strict mode. If true, use LayerNormalization kernel. + // The strict mode has better accuracy but lower performance. }; diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 3bd6dac54b..600e255bcd 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -36,7 +36,7 @@ struct OrtTensorRTProviderOptionsV2 { int trt_detailed_build_log; // Enable detailed build step logging on TensorRT EP with timing for each engine build. Default 0 = false, nonzero = true int trt_build_heuristics_enable; // Build engine using heuristics to reduce build time. Default 0 = false, nonzero = true int trt_sparsity_enable; // Control if sparsity can be used by TRT. Default 0 = false, 1 = true - int trt_builder_optimization_level; // Set the builder optimization level. 
WARNING: levels below 2 do not guarantee good engine performance, but greatly improve build time. Default 2, valid range [0-4] + int trt_builder_optimization_level; // Set the builder optimization level. WARNING: levels below 3 do not guarantee good engine performance, but greatly improve build time. Default 3, valid range [0-5] int trt_auxiliary_streams; // Set maximum number of auxiliary streams per inference stream. Setting this value to 0 will lead to optimal memory usage. Default -1 = heuristics const char* trt_tactic_sources; // pecify the tactics to be used by adding (+) or removing (-) tactics from the default // tactic sources (default = all available tactics) e.g. "-CUDNN,+CUBLAS" available keys: "CUBLAS"|"CUBLAS_LT"|"CUDNN"|"EDGE_MASK_CONVOLUTIONS" diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 7f8cc059a6..d215b12157 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -95,6 +95,8 @@ extern "C" { #define ORTCHAR_T char #endif +/// ORTCHAR_T, ORT_TSTR are reserved specifically for path handling. +/// All other strings are UTF-8 encoded, use char and std::string #ifndef ORT_TSTR #ifdef _WIN32 #define ORT_TSTR(X) L##X @@ -627,11 +629,19 @@ struct OrtApiBase { * \param[in] version Must be ::ORT_API_VERSION * \return The ::OrtApi for the version requested, nullptr will be returned if this version is unsupported, for example when using a runtime * older than the version created with this header file. + * + * One can call GetVersionString() to get the version of the Onnxruntime library for logging + * and error reporting purposes. 
*/ const OrtApi*(ORT_API_CALL* GetApi)(uint32_t version)NO_EXCEPTION; - const ORTCHAR_T*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; ///< Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1") - const ORTCHAR_T*(ORT_API_CALL* GetBuildInfoString)(void)NO_EXCEPTION; ///< Returns a null terminated string of the build info including git info and cxx flags + + /** \brief Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1") + * + * \return UTF-8 encoded version string. Do not deallocate the returned buffer. + */ + const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; }; + typedef struct OrtApiBase OrtApiBase; /** \brief The Onnxruntime library's entry point to access the C API @@ -4192,6 +4202,14 @@ struct OrtApi { * \since Version 1.15. */ ORT_API2_STATUS(KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out); + + /** \brief Returns a null terminated string of the build info including git info and cxx flags + * + * \return UTF-8 encoded version string. Do not deallocate the returned buffer. + * + * \since Version 1.15. + */ + const char*(ORT_API_CALL* GetBuildInfoString)(void); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index df007b7af4..2d5e1a9bdd 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -125,14 +125,14 @@ inline const OrtApi& GetApi() noexcept { return *Global::api_; } /// This function returns the onnxruntime version string /// /// version string major.minor.rev -std::basic_string GetVersionString(); +std::string GetVersionString(); /// /// This function returns the onnxruntime build information: including git branch, /// git commit id, build type(Debug/Release/RelWithDebInfo) and cmake cpp flags. 
/// /// string -std::basic_string GetBuildInfoString(); +std::string GetBuildInfoString(); /// /// This is a C++ wrapper for OrtApi::GetAvailableProviders() and diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 3e7cf721cd..b72bcd35fa 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1987,14 +1987,12 @@ inline void CustomOpApi::ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_c api_.ReleaseKernelInfo(info_copy); } -inline std::basic_string GetVersionString() { - std::basic_string result = OrtGetApiBase()->GetVersionString(); - return result; +inline std::string GetVersionString() { + return OrtGetApiBase()->GetVersionString(); } -inline std::basic_string GetBuildInfoString() { - std::basic_string result = OrtGetApiBase()->GetBuildInfoString(); - return result; +inline std::string GetBuildInfoString() { + return GetApi().GetBuildInfoString(); } inline std::vector GetAvailableProviders() { diff --git a/java/src/main/java/ai/onnxruntime/OrtProvider.java b/java/src/main/java/ai/onnxruntime/OrtProvider.java index cb35bf4f50..0da9487c67 100644 --- a/java/src/main/java/ai/onnxruntime/OrtProvider.java +++ b/java/src/main/java/ai/onnxruntime/OrtProvider.java @@ -23,7 +23,8 @@ public enum OrtProvider { ARM_NN("ArmNNExecutionProvider"), ROCM("ROCMExecutionProvider"), CORE_ML("CoreMLExecutionProvider"), - XNNPACK("XnnpackExecutionProvider"); + XNNPACK("XnnpackExecutionProvider"), + AZURE("AzureExecutionProvider"); private static final Map valueMap = new HashMap<>(values().length); diff --git a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c index 5165f3499e..659f34e1fb 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c +++ b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c @@ -76,13 +76,7 @@ JNIEXPORT jobjectArray JNICALL 
Java_ai_onnxruntime_OnnxRuntime_getAvailableProvi JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseVersion (JNIEnv * jniEnv, jclass clazz) { (void)clazz; // required JNI parameter not needed by functions which don't access their host class. - const ORTCHAR_T* version = OrtGetApiBase()->GetVersionString(); + const char* version = ORT_VERSION; assert(version != NULL); -#ifdef _WIN32 - jsize len = (jsize)(wcslen(version)); - jstring versionStr = (*jniEnv)->NewString(jniEnv, (const jchar*)version, len); -#else - jstring versionStr = (*jniEnv)->NewStringUTF(jniEnv, version); -#endif - return versionStr; + return (*jniEnv)->NewStringUTF(jniEnv, version); } diff --git a/js/common/package-lock.json b/js/common/package-lock.json index 1d956b5888..49816816a5 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-common", - "version": "1.15.0", + "version": "1.15.1", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-common", - "version": "1.15.0", + "version": "1.15.1", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" diff --git a/js/common/package.json b/js/common/package.json index 39e9e356e3..357fd7b099 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -8,7 +8,7 @@ }, "author": "fs-eire", "module": "dist/lib/index.js", - "version": "1.15.0", + "version": "1.15.1", "jsdelivr": "dist/ort-common.min.js", "scripts": { "prepare": "tsc && webpack" diff --git a/js/node/package-lock.json b/js/node/package-lock.json index 688e434121..8c9b3f049b 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-node", - "version": "1.15.0", + "version": "1.15.1", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-node", - "version": "1.15.0", + "version": "1.15.1", "license": "MIT", "os": [ "win32", @@ -28,7 +28,8 @@ } }, "../common": { - "version": "1.15.0", + "name": 
"onnxruntime-common", + "version": "1.15.1", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" diff --git a/js/node/package.json b/js/node/package.json index 8bad9d2a21..33d83edb9a 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -13,7 +13,7 @@ 3 ] }, - "version": "1.15.0", + "version": "1.15.1", "dependencies": { "onnxruntime-common": "file:../common" }, diff --git a/js/package-lock.json b/js/package-lock.json index deb97d1a07..4163656b34 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -2403,8 +2403,8 @@ } }, "node_modules/fastq": { - "version": "1.15.0", - "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", + "version": "1.15.1", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.1.tgz", "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==", "dev": true, "dependencies": { @@ -7532,8 +7532,8 @@ "dev": true }, "fastq": { - "version": "1.15.0", - "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", + "version": "1.15.1", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.1.tgz", "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==", "dev": true, "requires": { diff --git a/js/react_native/package.json b/js/react_native/package.json index 178b632bce..4a6efe6478 100644 --- a/js/react_native/package.json +++ b/js/react_native/package.json @@ -38,7 +38,7 @@ "registry": "https://registry.npmjs.org/" }, "source": "lib/index", - "version": "1.15.0", + "version": "1.15.1", "main": "dist/commonjs/index", "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", "files": [ diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index d5925f65d2..c1ecd6754e 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -5197,7 +5197,7 @@ onetime@^5.1.0, onetime@^5.1.2: mimic-fn "^2.1.0" 
"onnxruntime-common@file:../common": - version "1.15.0" + version "1.15.1" open@^6.2.0: version "6.4.0" diff --git a/js/web/.gitignore b/js/web/.gitignore index 438068ff32..5b9e31ad5e 100644 --- a/js/web/.gitignore +++ b/js/web/.gitignore @@ -15,6 +15,8 @@ test/**/*.js.map script/**/*.js script/**/*.js.map +!/types.d.ts + lib/wasm/binding/**/*.wasm !lib/wasm/binding/**/*.d.ts diff --git a/js/web/.npmignore b/js/web/.npmignore index e21a906c9a..8e08db5917 100644 --- a/js/web/.npmignore +++ b/js/web/.npmignore @@ -4,8 +4,7 @@ /dist/**/*.report.html -/types/**/*.d.ts -!/types/lib/**/*.d.ts +/types/ karma.conf.js tsconfig.json diff --git a/js/web/docs/operators.md b/js/web/docs/operators.md index 115884cb9d..de84134ddb 100644 --- a/js/web/docs/operators.md +++ b/js/web/docs/operators.md @@ -19,7 +19,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Asinh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Asinh) | | | [Atan](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Atan) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Atan-7) | | [Atanh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Atanh) | | -| [AveragePool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool) | [7-9](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-7), [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-10), [11+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-11) | +| [AveragePool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool) | [7-9](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-7), [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-10), [11-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-11), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-19) | | 
[BatchNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BatchNormalization) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-7), [9-13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-9), [14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-14), [15+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-15) | | [Bernoulli](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Bernoulli) | | | [BitShift](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BitShift) | | @@ -28,7 +28,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [BitwiseOr](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BitwiseOr) | | | [BitwiseXor](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BitwiseXor) | | | [BlackmanWindow](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BlackmanWindow) | | -| [Cast](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast) | [6-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-6), [9-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-9), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-13) | +| [Cast](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast) | [6-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-6), [9-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-9), [13-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-13), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-19) | | [CastLike](https://github.com/onnx/onnx/blob/main/docs/Operators.md#CastLike) | | | [Ceil](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Ceil) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Ceil-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Ceil-13) | | 
[Celu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Celu) | | @@ -47,6 +47,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Cosh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cosh) | | | [CumSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum) | | | [DFT](https://github.com/onnx/onnx/blob/main/docs/Operators.md#DFT) | | +| [DeformConv](https://github.com/onnx/onnx/blob/main/docs/Operators.md#DeformConv) | | | [DepthToSpace](https://github.com/onnx/onnx/blob/main/docs/Operators.md#DepthToSpace) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#DepthToSpace-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#DepthToSpace-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#DepthToSpace-13) | | [DequantizeLinear](https://github.com/onnx/onnx/blob/main/docs/Operators.md#DequantizeLinear) | | | [Det](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Det) | | @@ -55,7 +56,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [DynamicQuantizeLinear](https://github.com/onnx/onnx/blob/main/docs/Operators.md#DynamicQuantizeLinear) | | | [Einsum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Einsum) | | | [Elu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Elu) | [6+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Elu-6) | -| [Equal](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Equal) | [7-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-7), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-13) | +| [Equal](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Equal) | [7-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-7), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-11), 
[13-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-13), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-19) | | [Erf](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Erf) | | | [Exp](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Exp) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-13) | | [Expand](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand) | | @@ -79,7 +80,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [HardSigmoid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSigmoid) | | | [HardSwish](https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSwish) | | | [Hardmax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Hardmax) | | -| [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16) | +| [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19) | | [If](https://github.com/onnx/onnx/blob/main/docs/Operators.md#If) | | | [InstanceNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#InstanceNormalization) | 
[6+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#InstanceNormalization-6) | | [IsInf](https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf) | | @@ -120,7 +121,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [OptionalHasElement](https://github.com/onnx/onnx/blob/main/docs/Operators.md#OptionalHasElement) | | | [Or](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Or) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Or-7) | | [PRelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#PRelu) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-7), [9-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-9), [16+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-16) | -| [Pad](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-18) | +| [Pad](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-19) | | [Pow](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pow) | [7-11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-7), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-12), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-13), [15+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-15) | | 
[QLinearConv](https://github.com/onnx/onnx/blob/main/docs/Operators.md#QLinearConv) | | | [QLinearMatMul](https://github.com/onnx/onnx/blob/main/docs/Operators.md#QLinearMatMul) | | @@ -143,8 +144,8 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [ReduceSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-11) | | [ReduceSumSquare](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSumSquare) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-18) | | [Relu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Relu) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-6), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-14) | -| [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14) | -| [Resize](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize) | [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-10), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-18) | +| 
[Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19) | +| [Resize](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize) | [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-10), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-19) | | [ReverseSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReverseSequence) | | | [RoiAlign](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RoiAlign) | | | [Round](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Round) | | @@ -161,7 +162,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [SequenceInsert](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceInsert) | | | [SequenceLength](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceLength) | | | [SequenceMap](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceMap) | | -| [Shape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-1), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-13), [15+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-15) | +| [Shape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-1), 
[13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-13), [15-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-15), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-19) | | [Shrink](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shrink) | | | [Sigmoid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sigmoid) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sigmoid-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sigmoid-13) | | [Sign](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sign) | | diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js index 2a4e71e064..860a7d2e20 100644 --- a/js/web/karma.conf.js +++ b/js/web/karma.conf.js @@ -63,6 +63,8 @@ module.exports = function (config) { { pattern: 'dist/ort-wasm-threaded.wasm', included: false }, { pattern: 'dist/ort-wasm-simd.wasm', included: false }, { pattern: 'dist/ort-wasm-simd-threaded.wasm', included: false }, + { pattern: 'dist/ort-wasm-simd.jsep.wasm', included: false }, + { pattern: 'dist/ort-wasm-simd-threaded.jsep.wasm', included: false }, { pattern: 'dist/ort-wasm-threaded.worker.js', included: false }, ], proxies: { @@ -70,6 +72,8 @@ module.exports = function (config) { '/base/test/ort-wasm-threaded.wasm': '/base/dist/ort-wasm-threaded.wasm', '/base/test/ort-wasm-simd.wasm': '/base/dist/ort-wasm-simd.wasm', '/base/test/ort-wasm-simd-threaded.wasm': '/base/dist/ort-wasm-simd-threaded.wasm', + '/base/test/ort-wasm-simd.jsep.wasm': '/base/dist/ort-wasm-simd.jsep.wasm', + '/base/test/ort-wasm-simd-threaded.jsep.wasm': '/base/dist/ort-wasm-simd-threaded.jsep.wasm', '/base/test/ort-wasm-threaded.worker.js': '/base/dist/ort-wasm-threaded.worker.js', }, plugins: karmaPlugins, diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 683e15e884..7648f0c473 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -6,11 +6,16 @@ import * 
as path from 'path'; import {OrtWasmModule} from './binding/ort-wasm'; import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded'; -import ortWasmFactory from './binding/ort-wasm.js'; -const ortWasmFactoryThreaded: EmscriptenModuleFactory = - // eslint-disable-next-line @typescript-eslint/no-require-imports - !BUILD_DEFS.DISABLE_WASM_THREAD ? require('./binding/ort-wasm-threaded.js') : ortWasmFactory; +/* eslint-disable @typescript-eslint/no-require-imports */ +const ortWasmFactory: EmscriptenModuleFactory = + BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js'); + +const ortWasmFactoryThreaded: EmscriptenModuleFactory = !BUILD_DEFS.DISABLE_WASM_THREAD ? + (BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm-threaded.js') : + require('./binding/ort-wasm-simd-threaded.jsep.js')) : + ortWasmFactory; +/* eslint-enable @typescript-eslint/no-require-imports */ let wasm: OrtWasmModule|undefined; let initialized = false; @@ -95,10 +100,10 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const useThreads = numThreads > 1 && isMultiThreadSupported(); const useSimd = simd && isSimdSupported(); - const wasmPrefixOverride = typeof flags.wasmPaths === 'string' ? flags.wasmPaths : undefined; - const wasmFileName = getWasmFileName(false, useThreads); - const wasmOverrideFileName = getWasmFileName(useSimd, useThreads); - const wasmPathOverride = typeof flags.wasmPaths === 'object' ? flags.wasmPaths[wasmOverrideFileName] : undefined; + const wasmPaths = flags.wasmPaths; + const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined; + const wasmFileName = getWasmFileName(useSimd, useThreads); + const wasmPathOverride = typeof wasmPaths === 'object' ? 
wasmPaths[wasmFileName] : undefined; let isTimeout = false; @@ -130,9 +135,22 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise {type: 'text/javascript'})); } - if (fileName === wasmFileName) { - const prefix: string = wasmPrefixOverride ?? scriptDirectory; - return wasmPathOverride ?? prefix + wasmOverrideFileName; + if (fileName.endsWith('.wasm')) { + if (wasmPathOverride) { + return wasmPathOverride; + } + + const prefix = wasmPrefixOverride ?? scriptDirectory; + + if (!BUILD_DEFS.DISABLE_WEBGPU) { + if (wasmFileName === 'ort-wasm-simd.wasm') { + return prefix + 'ort-wasm-simd.jsep.wasm'; + } else if (wasmFileName === 'ort-wasm-simd-threaded.wasm') { + return prefix + 'ort-wasm-simd-threaded.jsep.wasm'; + } + } + + return prefix + wasmFileName; } return scriptDirectory + fileName; diff --git a/js/web/package-lock.json b/js/web/package-lock.json index ad44566290..dd18e3eac0 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.15.0", + "version": "1.15.1", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.15.0", + "version": "1.15.1", "license": "MIT", "dependencies": { "flatbuffers": "^1.12.0", @@ -52,7 +52,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.15.0", + "version": "1.15.1", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" @@ -1158,9 +1158,9 @@ } }, "node_modules/engine.io": { - "version": "6.4.1", - "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.1.tgz", - "integrity": "sha512-JFYQurD/nbsA5BSPmbaOSLa3tSVj8L6o4srSwXXY3NqE+gGUNmmPTbhn8tjzcCtSqhFgIeqef81ngny8JM25hw==", + "version": "6.4.2", + "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.2.tgz", + "integrity": "sha512-FKn/3oMiJjrOEOeUub2WCox6JhxBXq/Zn3fZOMCBxKnNYtsdKjxhl7yR3fZhM9PV+rdE75SU5SYMc+2PGzo+Tg==", "dev": true, "dependencies": { "@types/cookie": "^0.4.1", @@ 
-4933,9 +4933,9 @@ } }, "engine.io": { - "version": "6.4.1", - "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.1.tgz", - "integrity": "sha512-JFYQurD/nbsA5BSPmbaOSLa3tSVj8L6o4srSwXXY3NqE+gGUNmmPTbhn8tjzcCtSqhFgIeqef81ngny8JM25hw==", + "version": "6.4.2", + "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.2.tgz", + "integrity": "sha512-FKn/3oMiJjrOEOeUub2WCox6JhxBXq/Zn3fZOMCBxKnNYtsdKjxhl7yR3fZhM9PV+rdE75SU5SYMc+2PGzo+Tg==", "dev": true, "requires": { "@types/cookie": "^0.4.1", diff --git a/js/web/package.json b/js/web/package.json index 9b18dcbd23..0de6a2ef1f 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -8,7 +8,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.15.0", + "version": "1.15.1", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^1.12.0", @@ -66,6 +66,22 @@ "strip-json-comments": "^5.0.0" }, "main": "dist/ort-web.node.js", - "types": "./types/lib/index.d.ts", + "exports": { + ".": { + "node": { + "types": "./types.d.ts", + "default": "./dist/ort-web.node.js" + }, + "default": { + "types": "./types.d.ts", + "default": "./dist/ort.min.js" + } + }, + "./webgpu": { + "types": "./types.d.ts", + "default": "./dist/ort.webgpu.min.js" + } + }, + "types": "./types.d.ts", "description": "A Javascript library for running ONNX models on browsers" } diff --git a/js/web/script/build.ts b/js/web/script/build.ts index ebdce462b9..d3a5be429b 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -34,8 +34,11 @@ const ROOT_FOLDER = path.join(__dirname, '..'); const WASM_BINDING_FOLDER = path.join(ROOT_FOLDER, 'lib', 'wasm', 'binding'); const WASM_BINDING_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm.js'); const WASM_BINDING_THREADED_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.js'); +const WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.js'); const WASM_BINDING_THREADED_WORKER_JS_PATH = 
path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.worker.js'); const WASM_BINDING_THREADED_MIN_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.min.js'); +const WASM_BINDING_SIMD_THREADED_JSEP_MIN_JS_PATH = + path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.min.js'); const WASM_BINDING_THREADED_MIN_WORKER_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.min.worker.js'); const WASM_DIST_FOLDER = path.join(ROOT_FOLDER, 'dist'); @@ -43,8 +46,11 @@ const WASM_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm.wasm'); const WASM_THREADED_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-threaded.wasm'); const WASM_SIMD_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd.wasm'); const WASM_SIMD_THREADED_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.wasm'); +const WASM_SIMD_JSEP_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd.jsep.wasm'); +const WASM_SIMD_THREADED_JSEP_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm'); const WASM_THREADED_WORKER_JS_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-threaded.worker.js'); const WASM_THREADED_JS_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-threaded.js'); +const WASM_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.jsep.js'); function validateFile(path: string): void { npmlog.info('Build', `Ensure file: ${path}`); @@ -61,11 +67,14 @@ if (WASM) { try { validateFile(WASM_BINDING_JS_PATH); validateFile(WASM_BINDING_THREADED_JS_PATH); + validateFile(WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH); validateFile(WASM_BINDING_THREADED_WORKER_JS_PATH); validateFile(WASM_WASM_PATH); validateFile(WASM_THREADED_WASM_PATH); validateFile(WASM_SIMD_WASM_PATH); validateFile(WASM_SIMD_THREADED_WASM_PATH); + validateFile(WASM_SIMD_JSEP_WASM_PATH); + validateFile(WASM_SIMD_THREADED_JSEP_WASM_PATH); } catch (e) { npmlog.error('Build', `WebAssembly files are not ready. build WASM first. 
ERR: ${e}`); throw e; @@ -86,7 +95,7 @@ if (WASM) { 'npx', [ 'terser', WASM_BINDING_THREADED_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false', '--mangle', - 'reserved=[_scriptDir]', '--module' + 'reserved=[_scriptDir,startWorker]', '--module' ], {shell: true, encoding: 'utf-8', cwd: ROOT_FOLDER}); if (terser.status !== 0) { @@ -105,13 +114,38 @@ if (WASM) { } npmlog.info('Build', 'Minimizing file "ort-wasm-threaded.js"... DONE'); + npmlog.info('Build', 'Minimizing file "ort-wasm-simd-threaded.jsep.js"...'); + try { + const terser = spawnSync( + 'npx', + [ + 'terser', WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false', + '--mangle', 'reserved=[_scriptDir,startWorker]', '--module' + ], + {shell: true, encoding: 'utf-8', cwd: ROOT_FOLDER}); + if (terser.status !== 0) { + console.error(terser.error); + process.exit(terser.status === null ? undefined : terser.status); + } + + fs.writeFileSync(WASM_BINDING_SIMD_THREADED_JSEP_MIN_JS_PATH, terser.stdout); + fs.writeFileSync(WASM_SIMD_THREADED_JSEP_JS_PATH, `${COPYRIGHT_BANNER}${terser.stdout}`); + + validateFile(WASM_BINDING_SIMD_THREADED_JSEP_MIN_JS_PATH); + validateFile(WASM_SIMD_THREADED_JSEP_JS_PATH); + } catch (e) { + npmlog.error('Build', `Failed to run terser on ort-wasm-simd-threaded.jsep.js. ERR: ${e}`); + throw e; + } + npmlog.info('Build', 'Minimizing file "ort-wasm-simd-threaded.jsep.js"...
DONE'); + npmlog.info('Build', 'Minimizing file "ort-wasm-threaded.worker.js"...'); try { const terser = spawnSync( 'npx', [ 'terser', WASM_BINDING_THREADED_WORKER_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false', - '--mangle', 'reserved=[_scriptDir]', '--toplevel' + '--mangle', 'reserved=[_scriptDir,startWorker]', '--toplevel' ], {shell: true, encoding: 'utf-8'}); if (terser.status !== 0) { diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index bd6fa1af4f..8c2f24cbf7 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -112,10 +112,14 @@ downloadJson( extractFile(zip, WASM_FOLDER, 'ort-wasm-threaded.wasm', 'Release_wasm'); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd.wasm', 'Release_wasm'); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.wasm', 'Release_wasm'); + extractFile(zip, WASM_FOLDER, 'ort-wasm-simd.jsep.wasm', 'Release_wasm'); + extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm', 'Release_wasm'); extractFile(zip, JS_FOLDER, 'ort-wasm.js', 'Release_wasm'); extractFile(zip, JS_FOLDER, 'ort-wasm-threaded.js', 'Release_wasm'); extractFile(zip, JS_FOLDER, 'ort-wasm-threaded.worker.js', 'Release_wasm'); + extractFile(zip, JS_FOLDER, 'ort-wasm-simd.jsep.js', 'Release_wasm'); + extractFile(zip, JS_FOLDER, 'ort-wasm-simd-threaded.jsep.js', 'Release_wasm'); }); }); }); diff --git a/js/web/types.d.ts b/js/web/types.d.ts new file mode 100644 index 0000000000..c6cff64c8a --- /dev/null +++ b/js/web/types.d.ts @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +declare module 'onnxruntime-web' { + export * from 'onnxruntime-common'; +} + +declare module 'onnxruntime-web/webgpu' { + export * from 'onnxruntime-web'; +} diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js index 1c842ddced..84ab9bcae8 100644 --- a/js/web/webpack.config.js +++ b/js/web/webpack.config.js @@ -49,7 +49,7 @@ function defaultTerserPluginOptions(target) { passes: 2 }, mangle: { - reserved: ["_scriptDir"] + reserved: ["_scriptDir","startWorker"] } } }; @@ -57,7 +57,7 @@ function defaultTerserPluginOptions(target) { const DEFAULT_BUILD_DEFS = { DISABLE_WEBGL: false, - DISABLE_WEBGPU: false, + DISABLE_WEBGPU: true, DISABLE_WASM: false, DISABLE_WASM_PROXY: false, DISABLE_WASM_THREAD: false, @@ -123,6 +123,7 @@ function buildConfig({ filename, format, target, mode, devtool, build_defs }) { if (mode === 'production') { config.resolve.alias['./binding/ort-wasm-threaded.js'] = './binding/ort-wasm-threaded.min.js'; + config.resolve.alias['./binding/ort-wasm-threaded-simd.jsep.js'] = './binding/ort-wasm-threaded-simd.jsep.min.js'; config.resolve.alias['./binding/ort-wasm-threaded.worker.js'] = './binding/ort-wasm-threaded.min.worker.js'; const options = defaultTerserPluginOptions(target); @@ -211,7 +212,8 @@ function buildTestRunnerConfig({ format = 'umd', target = 'es2017', mode = 'production', - devtool = 'source-map' + devtool = 'source-map', + build_defs = DEFAULT_BUILD_DEFS }) { const config = { target: ['web', target], @@ -244,7 +246,7 @@ function buildTestRunnerConfig({ } }, plugins: [ - new webpack.DefinePlugin({ BUILD_DEFS: DEFAULT_BUILD_DEFS }), + new webpack.DefinePlugin({ BUILD_DEFS: build_defs }), new webpack.WatchIgnorePlugin({ paths: [/\.js$/, /\.d\.ts$/] }), new NodePolyfillPlugin({ excludeAliases: ["console", "Buffer"] @@ -315,6 +317,13 @@ module.exports = () => { DISABLE_WASM_THREAD: true, } }), + // ort.webgpu.min.js + buildOrtConfig({ + suffix: '.webgpu.min', build_defs: { + ...DEFAULT_BUILD_DEFS, + DISABLE_WEBGPU: false, + 
} + }), // ort-web.min.js buildOrtWebConfig({ suffix: '.min' }), @@ -333,10 +342,20 @@ module.exports = () => { ); break; case 'dev': - builds.push(buildTestRunnerConfig({ suffix: '.dev', mode: 'development', devtool: 'inline-source-map' })); + builds.push(buildTestRunnerConfig({ + suffix: '.dev', mode: 'development', devtool: 'inline-source-map', build_defs: { + ...DEFAULT_BUILD_DEFS, + DISABLE_WEBGPU: false, + } + })); break; case 'perf': - builds.push(buildTestRunnerConfig({ suffix: '.perf' })); + builds.push(buildTestRunnerConfig({ + suffix: '.perf', build_defs: { + ...DEFAULT_BUILD_DEFS, + DISABLE_WEBGPU: false, + } + })); break; default: throw new Error(`unsupported bundle mode: ${bundleMode}`); diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 1da5824cf7..14aa188ac0 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -7,7 +7,7 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_ or the `Github project `_. """ -__version__ = "1.15.0" +__version__ = "1.15.1" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). 
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 31b1b12d38..90f5abab87 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -48,6 +48,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Quick // ******** Start: Quantization ******************* // class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConvTranspose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConcat); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearWhere); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool); @@ -181,6 +182,7 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/tokenizer.cc b/onnxruntime/contrib_ops/cpu/tokenizer.cc index 45998b6d83..1787fb9b3c 100644 --- a/onnxruntime/contrib_ops/cpu/tokenizer.cc +++ b/onnxruntime/contrib_ops/cpu/tokenizer.cc @@ -242,7 +242,7 @@ Status Tokenizer::SeparatorExpressionTokenizer(OpKernelContext* ctx, token_len, utf8_chars); if (!valid) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Match contains invalid utf8 chars: " + submatch.as_string()); + "Match contains invalid utf8 chars: " + std::string{submatch}); } if (utf8_chars >= size_t(mincharnum_)) { tokens.emplace_back(text.data() + start_pos, token_len); @@ -384,7 +384,7 @@ Status 
Tokenizer::TokenExpression(OpKernelContext* ctx, utf8_chars = 0; if (!utf8_len(reinterpret_cast(submatch.data()), token_len, utf8_chars)) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Match contains invalid utf8 chars: " + submatch.as_string()); + "Match contains invalid utf8 chars: " + std::string{submatch}); } if (utf8_chars >= size_t(mincharnum_)) { row.push_back(submatch); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index def1508ca2..bd1498c0f9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -164,6 +164,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { has_memory_efficient_attention(sm, sizeof(T) == 2); #else constexpr bool use_memory_efficient_attention = false; + ORT_UNUSED_VARIABLE(is_mask_1d_key_seq_len_start); #endif cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu index b5c8dcca32..ebe87158d1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu @@ -156,7 +156,7 @@ struct TypeMapper : public V_vec_m_ {}; // The following operator overriding is not common so we put it in anonymous namespace #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 530 inline __device__ half2 operator*(const float a, const half2 b) { - return __hmul2_rn(__float2half2_rn(a), b); + return __hmul2(__float2half2_rn(a), b); } #else inline __device__ half2 operator*(const float a, const half2 b) { diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index a6e040812a..1334708aad 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -1324,6 +1324,7 @@ class PlannerImpl { 
#endif ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i)); ClearUseCount(); + freelist_.clear(); // DONOT share freelist across streams } #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) CalculateLifetime(ort_value_usecount); @@ -1745,9 +1746,8 @@ class PlannerImpl { #else - void - PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers, - const PathString& partition_config_file) { + void PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers, + const PathString& partition_config_file) { auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file); auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder()); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); @@ -1760,7 +1760,7 @@ class PlannerImpl { num_logic_streams_ = stream_nodes_.size(); } - // build each logic streams + // Build each logic streams Status BuildExecutionPlan(const ExecutionProviders& execution_providers, const IStreamCommandHandleRegistry& stream_handle_registry) { // 1. create logic stream instance @@ -1780,12 +1780,12 @@ class PlannerImpl { execution_plan.emplace_back(nullptr); } } - // 2. determing following things: - // a. which node need to generate notification - // b. which node need to trigger downstream + // 2. Determining following things: + // a. which node needs to generate the notification + // b. which node needs to trigger downstream #ifdef ENABLE_TRAINING // We will leverage the topological order for the training scenario. 
- // The nodes before yieldOp in topo order will be executed in RunForward() and nodes after will be executed in RunBackward() + // The nodes before yieldOp in topo-order will be executed in RunForward() and nodes after will be executed in RunBackward() // This partition may not be exactly the same as forward model/gradient model, for example, some nodes in gradient model are // before yieldOp thus will be executed in RunForward() // But the final result is still correct, as long as all the nodes will be executed in either RunForward() or RunBackward() @@ -1820,7 +1820,7 @@ class PlannerImpl { if (node_stream_map_[it->Index()] != i #ifdef ENABLE_TRAINING // Do not insert Barrier/TriggerDownStream step if the producer and consumer are in different sides of yieldOp - // As in this case producer will surely be ready before consumer is running. + // As in this case producer will surely be ready before the consumer is running. && !AreNodesSeparatedByYield(node_index, it->Index()) #endif ) { @@ -2048,8 +2048,7 @@ class PlannerImpl { } #endif - static bool - IsNonTensor(const onnxruntime::NodeArg& nodearg) { + static bool IsNonTensor(const onnxruntime::NodeArg& nodearg) { // TODO: unclear why we should go through a string-representation of type auto ptype = nodearg.Type(); auto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(ptype); diff --git a/onnxruntime/core/framework/cloud_invoker.cc b/onnxruntime/core/framework/cloud_invoker.cc index d6883d408f..a2e9ec97cf 100644 --- a/onnxruntime/core/framework/cloud_invoker.cc +++ b/onnxruntime/core/framework/cloud_invoker.cc @@ -2,7 +2,9 @@ // Licensed under the MIT License. 
#ifdef USE_AZURE +#define CURL_STATICLIB #include "http_client.h" +#include "curl/curl.h" #include "core/common/common.h" #include "core/framework/cloud_invoker.h" #include "core/framework/ort_value.h" @@ -18,13 +20,14 @@ namespace onnxruntime { namespace tc = triton::client; -const char* kAzureUri = "azure.uri"; -const char* kAzureModelName = "azure.model_name"; -const char* kAzureModelVer = "azure.model_version"; -const char* kAzureVerbose = "azure.verbose"; -const char* kAzureEndpointType = "azure.endpoint_type"; -const char* kAzureAuthKey = "azure.auth_key"; -const char* kAzureTriton = "triton"; +constexpr const char* kAzureUri = "azure.uri"; +constexpr const char* kAzureModelName = "azure.model_name"; +constexpr const char* kAzureModelVer = "azure.model_version"; +constexpr const char* kAzureVerbose = "azure.verbose"; +constexpr const char* kAzureEndpointType = "azure.endpoint_type"; +constexpr const char* kAzureAuthKey = "azure.auth_key"; +constexpr const char* kAzureTriton = "triton"; +constexpr const char* kAzureOpenAI = "openai"; CloudEndPointInvoker::CloudEndPointInvoker(const CloudEndPointConfig& config, const AllocatorPtr& allocator) : config_(config), allocator_(allocator) { @@ -33,6 +36,163 @@ CloudEndPointInvoker::CloudEndPointInvoker(const CloudEndPointConfig& config, } } +class CurlGlobal { + public: + static void Init() { + static CurlGlobal curl_global; + } + + private: + CurlGlobal() { + // Thread-safety is a must since curl might also be initialized in triton client. 
+ const auto* info = curl_version_info(CURLVERSION_NOW); + ORT_ENFORCE(info->features & CURL_VERSION_THREADSAFE, "curl global init not thread-safe, need to upgrade curl version!"); + ORT_ENFORCE(curl_global_init(CURL_GLOBAL_DEFAULT) == CURLE_OK, "Failed to initialize curl global env!"); + } + ~CurlGlobal() { + curl_global_cleanup(); + } + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CurlGlobal); +}; + +// OpenAIInvoker +class OpenAIInvoker : public CloudEndPointInvoker { + public: + OpenAIInvoker(const CloudEndPointConfig& config, const AllocatorPtr& allocator); + onnxruntime::Status Send(const CloudEndPointConfig& run_options, + const InlinedVector& input_names, + gsl::span ort_inputs, + const InlinedVector& output_names, + std::vector& ort_outputs) const override; + + private: + std::string uri_; + std::string model_name_; +}; + +OpenAIInvoker::OpenAIInvoker(const CloudEndPointConfig& config, + const AllocatorPtr& allocator) : CloudEndPointInvoker(config, allocator) { + CurlGlobal::Init(); + ReadConfig(kAzureUri, uri_); + ReadConfig(kAzureModelName, model_name_); +} + +struct StringBuffer { + StringBuffer() = default; + ~StringBuffer() = default; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(StringBuffer); + std::stringstream ss_; +}; + +// apply the callback only when response is for sure to be a '/0' terminated string +static size_t WriteStringCallback(void* contents, size_t size, size_t nmemb, void* userp) { + try { + size_t realsize = size * nmemb; + auto buffer = reinterpret_cast(userp); + buffer->ss_.write(reinterpret_cast(contents), realsize); + return realsize; + } catch (...) 
{ + // exception caught, abort write + return CURLcode::CURLE_WRITE_ERROR; + } +} + +using CurlWriteCallBack = size_t (*)(void*, size_t, size_t, void*); + +class CurlHandler { + public: + CurlHandler(CurlWriteCallBack call_back) : curl_(curl_easy_init(), curl_easy_cleanup), + headers_(nullptr, curl_slist_free_all), + from_holder_(from_, curl_formfree) { + curl_easy_setopt(curl_.get(), CURLOPT_BUFFERSIZE, 102400L); + curl_easy_setopt(curl_.get(), CURLOPT_NOPROGRESS, 1L); + curl_easy_setopt(curl_.get(), CURLOPT_USERAGENT, "curl/7.83.1"); + curl_easy_setopt(curl_.get(), CURLOPT_MAXREDIRS, 50L); + curl_easy_setopt(curl_.get(), CURLOPT_FTP_SKIP_PASV_IP, 1L); + curl_easy_setopt(curl_.get(), CURLOPT_TCP_KEEPALIVE, 1L); + curl_easy_setopt(curl_.get(), CURLOPT_WRITEFUNCTION, call_back); + } + ~CurlHandler() = default; + + void AddHeader(const char* data) { + headers_.reset(curl_slist_append(headers_.release(), data)); + } + template + void AddForm(Args... args) { + curl_formadd(&from_, &last_, args...); + } + template + void SetOption(CURLoption opt, T val) { + curl_easy_setopt(curl_.get(), opt, val); + } + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CurlHandler); + CURLcode Perform() { + SetOption(CURLOPT_HTTPHEADER, headers_.get()); + SetOption(CURLOPT_HTTPPOST, from_); + return curl_easy_perform(curl_.get()); + } + + private: + std::unique_ptr curl_; + std::unique_ptr headers_; + curl_httppost* from_{}; + curl_httppost* last_{}; + std::unique_ptr from_holder_; +}; + +onnxruntime::Status OpenAIInvoker::Send(const CloudEndPointConfig& run_options, + const InlinedVector& /*input_names*/, + gsl::span ort_inputs, + const InlinedVector& /*output_names*/, + std::vector& ort_outputs) const { + const auto auth_key_iter = run_options.find(kAzureAuthKey); + if (run_options.end() == auth_key_iter || auth_key_iter->second.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "auth key must be specified for openai client"); + } + long verbose = 0; + const auto verbose_iter = 
run_options.find(kAzureVerbose); + if (run_options.end() != verbose_iter) { + verbose = verbose_iter->second != "0" ? 1L : 0L; + } + + CurlHandler curl_handler(WriteStringCallback); + StringBuffer string_buffer; + + std::string full_auth = std::string{"Authorization: Bearer "} + auth_key_iter->second; + curl_handler.AddHeader(full_auth.c_str()); + curl_handler.AddHeader("Content-Type: multipart/form-data"); + + const auto& tensor = ort_inputs[0].Get(); + auto data_size = tensor.SizeInBytes(); + curl_handler.AddForm(CURLFORM_COPYNAME, "model", CURLFORM_COPYCONTENTS, model_name_.c_str(), CURLFORM_END); + curl_handler.AddForm(CURLFORM_COPYNAME, "response_format", CURLFORM_COPYCONTENTS, "text", CURLFORM_END); + curl_handler.AddForm(CURLFORM_COPYNAME, "file", CURLFORM_BUFFER, "non_exist.wav", CURLFORM_BUFFERPTR, tensor.DataRaw(), + CURLFORM_BUFFERLENGTH, data_size, CURLFORM_END); + + curl_handler.SetOption(CURLOPT_URL, uri_.c_str()); + curl_handler.SetOption(CURLOPT_VERBOSE, verbose); + curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&string_buffer); + + auto curl_ret = curl_handler.Perform(); + if (CURLE_OK != curl_ret) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, curl_easy_strerror(curl_ret)); + } + + auto output_tensor = std::make_unique(onnxruntime::DataTypeImpl::GetType(), TensorShape{1}, allocator_); + if (!output_tensor) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor"); + } + + auto* output_string = output_tensor->MutableData(); + *output_string = string_buffer.ss_.str(); + auto tensor_type = DataTypeImpl::GetType(); + ort_outputs.clear(); + ort_outputs.emplace_back(output_tensor.release(), tensor_type, tensor_type->GetDeleteFunc()); + return Status::OK(); +} + +// AzureTritonInvoker class AzureTritonInvoker : public CloudEndPointInvoker { public: AzureTritonInvoker(const CloudEndPointConfig& config, const AllocatorPtr& allocator); @@ -287,6 +447,9 @@ Status CloudEndPointInvoker::CreateInvoker(const CloudEndPointConfig& config, if 
(iter->second == kAzureTriton) { invoker = std::make_unique(config, allocator); return status; + } else if (iter->second == kAzureOpenAI) { + invoker = std::make_unique(config, allocator); + return status; } // else other endpoint types ... } status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 00aeff37d7..d9c49dc6be 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -28,6 +28,7 @@ using namespace onnxruntime::common; namespace onnxruntime { #ifdef ORT_ENABLE_STREAM static StreamAwareArena* AsStreamBasedAllocator(AllocatorPtr allocator) { + ORT_ENFORCE(allocator.get() != nullptr, "allocator is nullptr"); if (allocator->Info().alloc_type == OrtArenaAllocator) { BFCArena* arena_ptr = static_cast(allocator.get()); return StreamAwareArena::FromBFCArena(*arena_ptr); @@ -137,7 +138,7 @@ Status IExecutionFrame::GetOutputs(gsl::span fetch_mlvalue_idxs, std: #endif -// Return nullptr if index map to an value that is an unused optional input/output +// Return nullptr if index map to a value that is an unused optional input/output const OrtValue* IExecutionFrame::GetNodeInputOrOutputMLValue(int index) const { int ort_value_idx = GetNodeIdxToMLValueIdx(index); return ort_value_idx != NodeIndexInfo::kInvalidEntry ? &(all_values_[ort_value_idx]) : nullptr; @@ -147,9 +148,9 @@ OrtValue* IExecutionFrame::GetMutableNodeInputOrOutputMLValue(int index) { return const_cast(GetNodeInputOrOutputMLValue(index)); } -// TO DO: make it thread safe -// This method is not thread safe! -// Return S_OK and nullptr if index map to an value that is an unused optional input/output +// TO DO: make it thread-safe +// This method is not thread-safe! 
+// Return S_OK and nullptr if index map to a value that is an unused optional input/output Status IExecutionFrame::GetOrCreateNodeOutputMLValue(const int output_index, int output_arg_index, const TensorShape* shape, OrtValue*& p_ort_value, @@ -191,7 +192,7 @@ Status IExecutionFrame::GetOrCreateNodeOutputMLValue(const int output_index, int } bool IExecutionFrame::TryGetInferredShape(int /*index*/, TensorShape& /*shape*/) const { - // By default, there is not information about inferred shape, so this default + // By default, there is no information about inferred shape, so this default // implementation always returns false. The derived class of IExecutionFrame // can override this function to provide, for example, activations' shape information. return false; @@ -213,7 +214,7 @@ Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { } int IExecutionFrame::GetNodeIdxToMLValueIdx(int index) const { - // the validity of index is checked by GetMLValueIndex + // The validity of the index is checked by GetMLValueIndex int ort_value_idx = node_index_info_.GetMLValueIndex(index); return ort_value_idx; } @@ -241,7 +242,7 @@ void IExecutionFrame::Init(gsl::span feed_mlvalue_idxs, gsl::span feed_mlvalue_idxs, gsl::span feed_mlvalue_idxs, gsl::span feed_mlvalue_idxs, gsl::span // Planning of one memory type should only happen once. #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) ORT_ENFORCE( - static_activation_memory_sizes_in_byte_.find(location.name) == + static_activation_memory_sizes_in_byte_.find(location.ToString()) == static_activation_memory_sizes_in_byte_.end(), "Memory type ", - location.name, + location.ToString(), " should only appear once."); // static_activation_memory_in_bytes_ is max virtual memory size the planner computes. // Memory dynamically allocated when executing kernels is not recorded using this field. 
- static_activation_memory_sizes_in_byte_[location.name] = peak_size; + static_activation_memory_sizes_in_byte_[location.ToString()] = peak_size; #endif // the memory pattern buffer will leave in the whole execution. #ifdef ORT_ENABLE_STREAM @@ -535,7 +536,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va AllocatorPtr alloc = nullptr; // if we have pre-calculated memory pattern, and the ort_value is not output mlvalue - // try to allocated on pre-allocated big chunk. + // try to allocate on pre-allocated big chunk. const auto& per_alloc_plan = GetAllocationPlan(ort_value_index); if (mem_patterns_ && per_alloc_plan.alloc_kind != AllocKind::kAllocateOutput && @@ -557,11 +558,11 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va } else { // the block size may vary especially if the model has NonZero ops, or different sequence lengths are // fed in, so use VERBOSE as the log level as it's expected. - // TODO: Should we re-use the block if the size is large enough? Would probably need to allow it + // TODO: Should we reuse the block if the size is large enough? Would probably need to allow it // to be freed if the size difference was too large so our memory usage doesn't stick at a high water mark LOGS(session_state_.Logger(), VERBOSE) << "For ort_value with index: " << ort_value_index << ", block in memory pattern size is: " << block->size_ - << " but the actually size is: " << size + << " but the actual size is: " << size << ", fall back to default allocation behavior"; } } @@ -572,6 +573,8 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va // no memory pattern, or the pattern is not correct. 
if (!alloc) alloc = GetAllocator(location); + ORT_ENFORCE(alloc && alloc.get() != nullptr, "Failed to get allocator for ", location.ToString()); + Stream* current_stream = GetValueStream(ort_value_index); if (current_stream) { #ifdef ORT_ENABLE_STREAM @@ -606,7 +609,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va // Dynamic activation size would be accessed by multiple threads // if parallel executor is used. std::unique_lock lock(mtx_); - dynamic_activation_memory_sizes_in_byte_[location.name] += size; + dynamic_activation_memory_sizes_in_byte_[location.ToString()] += size; session_state_.GetMemoryProfiler()->GetMemoryInfo().SetDynamicAllocation(ort_value_index); #endif } @@ -825,7 +828,7 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const { } // This method is not thread safe! -// Return S_OK and nullptr if index map to an value that is an unused optional input/output +// Return S_OK and nullptr if index map to a value that is an unused optional input/output Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) { return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); } @@ -930,7 +933,7 @@ bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const { } // Search for inferred shape. - // If inferred shape is found, it's assigned to "shape" so that caller can use it. + // If the inferred shape is found, it's assigned to "shape" so that caller can use it. 
if (inferred_shapes_ != nullptr) { auto it = inferred_shapes_->find(ort_value_idx); if (it != inferred_shapes_->end()) { diff --git a/onnxruntime/core/framework/memory_info.cc b/onnxruntime/core/framework/memory_info.cc index 41c8191778..afb338f47f 100644 --- a/onnxruntime/core/framework/memory_info.cc +++ b/onnxruntime/core/framework/memory_info.cc @@ -130,7 +130,7 @@ void PrintInforPerTensor(const MemoryInfo::AllocInfoPerTensor& alloc_info, const std::cout << "Index: " << alloc_info.mlvalue_index << ", "; std::cout << "Reuse inplace: " << alloc_info.inplace_reuse << ", "; std::cout << "Alloc type: " << alloc_info.alloc_kind << ", "; - std::cout << "Location: " << alloc_info.location.name << ", "; + std::cout << "Location: " << alloc_info.location.ToString() << ", "; std::cout << "lifetime: (" << alloc_info.lifetime_interval.first << ", " << alloc_info.lifetime_interval.second << "), "; std::cout << "planned block: (" << mem_info.planned_block.offset_ << ", " << (mem_info.planned_block.offset_ + mem_info.planned_block.size_) << "), "; @@ -142,24 +142,24 @@ void PrintInforPerTensor(const MemoryInfo::AllocInfoPerTensor& alloc_info, const void MemoryInfo::PrintMemoryInfoForLocation(const OrtDevice::DeviceType location) { for (const auto& map : tensors_memory_info_map_) { - if (map.first.device.Type() != location) continue; - std::cout << "Initializer in " << map.first.name << "\n"; + if (map.first.Type() != location) continue; + std::cout << "Initializer in " << map.first.ToString() << "\n"; const auto& initailizer_map = map.second.at(MapType::Initializer); for (const auto& item : initailizer_map) { - if (AllocPlan(item.first)->location.device.Type() != location) continue; + if (AllocPlan(item.first)->location.Type() != location) continue; PrintInforPerTensor(*AllocPlan(item.first), item.second, initailizer_map.GetAllocAddress(item.first)); } - std::cout << "\nStatic Activation in " << map.first.name << "\n"; + std::cout << "\nStatic Activation in " << 
map.first.ToString() << "\n"; const auto& static_activation_map = map.second.at(MapType::StaticActivation); for (const auto& item : static_activation_map) { - if (AllocPlan(item.first)->location.device.Type() != location) continue; + if (AllocPlan(item.first)->location.Type() != location) continue; PrintInforPerTensor(*AllocPlan(item.first), item.second, static_activation_map.GetAllocAddress(item.first)); } - std::cout << "\nDynamic Activation in " << map.first.name << "\n"; + std::cout << "\nDynamic Activation in " << map.first.ToString() << "\n"; const auto& dynamic_activation_map = map.second.at(MapType::DynamicActivation); for (const auto& item : dynamic_activation_map) { - if (AllocPlan(item.first)->location.device.Type() != location) continue; + if (AllocPlan(item.first)->location.Type() != location) continue; PrintInforPerTensor(*AllocPlan(item.first), item.second, dynamic_activation_map.GetAllocAddress(item.first)); } } @@ -273,10 +273,10 @@ void MemoryProfiler::CreateEvents(const std::string& p_name, // Create Event for each tensor auto& time_step_trace = GetMemoryInfo().time_step_trace_; for (const auto& location_map : GetMemoryInfo().tensors_memory_info_map_) { - const OrtMemoryInfo& memory_info = location_map.first; + const OrtDevice& memory_info = location_map.first; const auto& maptype_to_map_mapping = location_map.second; - if (memory_info.device.Type() != device_type) continue; + if (memory_info.Type() != device_type) continue; // If there is no tensor of a map_type, we skip creating event for that map_type if (maptype_to_map_mapping.find(map_type) == maptype_to_map_mapping.end()) continue; diff --git a/onnxruntime/core/framework/memory_info.h b/onnxruntime/core/framework/memory_info.h index 4b11dcaf7a..e38ccb3d94 100644 --- a/onnxruntime/core/framework/memory_info.h +++ b/onnxruntime/core/framework/memory_info.h @@ -122,7 +122,7 @@ class MemoryInfo { bool inplace_reuse{false}; OrtValueIndex reused_buffer{0}; // The index of the reused tensor, if no 
reuse, it is its own tensor. AllocKind alloc_kind{AllocKind::kAllocate}; - OrtMemoryInfo location; + OrtDevice location; }; struct AllocationSummary { @@ -185,7 +185,7 @@ class MemoryInfo { } // Key: The map type. E.g., initializer, activation. Value: A map from the tensor index to its memory information - std::map > tensors_memory_info_map_; + std::map > tensors_memory_info_map_; // Key: The tensor index. Value: The Allocation information per tensor std::map tensor_alloc_info_map_; diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index c4412648d7..e9ed36294d 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -27,6 +27,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QGemm); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAdd); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConcat); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConvTranspose); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearWhere); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearLeakyRelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearMul); @@ -109,6 +110,7 @@ class OpSet_Microsoft_ver1 { static void ForEachSchema(std::function fn) { fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index d44e0ab155..fd6b3df934 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1439,7 +1439,7 @@ class MLAS_HALF_GEMM_POSTPROCESSOR { /** * @brief Half precision activation functions, with optional sum tensor. * Supplied sum tensor must be the same layout as the GEMM output tensor. 
- * And the supplied sum tensor will be added to the final result. + * And the supplied sum tensor will be added to the tensor before activation. */ class MLAS_HALF_GEMM_ACTIVATION_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR { diff --git a/onnxruntime/core/mlas/lib/activate_fp16.cpp b/onnxruntime/core/mlas/lib/activate_fp16.cpp index f62ffe0db9..776ec67fcc 100644 --- a/onnxruntime/core/mlas/lib/activate_fp16.cpp +++ b/onnxruntime/core/mlas/lib/activate_fp16.cpp @@ -689,35 +689,35 @@ MlasActivationKernel( size_t n = CountN; while (n >= 8) { - MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer); MLAS_FLOAT16X8 AVec = MlasLoadFloat16x8(addsrc); + MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer); addsrc += 8; - Vector = ActivationFunction.Activate(Vector); Vector = MlasAddFloat16x8(Vector, AVec); + Vector = ActivationFunction.Activate(Vector); MlasStoreFloat16x8(buffer, Vector); buffer += 8; n -= 8; } if (n >= 4) { - MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer); MLAS_FLOAT16X4 AVec = MlasLoadFloat16x4(addsrc); + MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer); addsrc += 4; - Vector = ActivationFunction.Activate(Vector); Vector = MlasAddFloat16x4(Vector, AVec); + Vector = ActivationFunction.Activate(Vector); MlasStoreFloat16x4(buffer, Vector); buffer += 4; n -= 4; } if (n > 0) { - MLAS_FLOAT16X4 buf; - std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_)); MLAS_FLOAT16X4 addbuf; + MLAS_FLOAT16X4 buf; std::memcpy(&addbuf, addsrc, n * sizeof(_mlas_fp16_)); - MLAS_FLOAT16X4 res = ActivationFunction.Activate(buf); - res = MlasAddFloat16x4(res, addbuf); - MlasStorePartialFloat16x4(buffer, res, n); + std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_)); + buf = MlasAddFloat16x4(buf, addbuf); + buf = ActivationFunction.Activate(buf); + MlasStorePartialFloat16x4(buffer, buf, n); } CRow += ldc; @@ -858,8 +858,6 @@ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process( ) const { std::vector buffer(CountM*CountN); - MLAS_HALF_GEMM_2FLOAT_PROCESSOR proc(this->Activation_, 
buffer.data(), CountN); - proc.Process(C, StartM, StartN, CountM, CountN, ldc); _mlas_fp16_* Output = reinterpret_cast<_mlas_fp16_*>(C); auto* CRow = buffer.data(); @@ -876,6 +874,8 @@ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process( } CAdd += ldc; } + MlasActivation(&this->Activation_, CRow, nullptr, 1, CountN, CountN); + CvtFloat2Half(Output, CRow, CountN); CRow += CountN; Output += ldc; diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index ea620877ed..7155a9dd27 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -782,7 +782,7 @@ extern "C" { // value. // -#define MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT 32 +#define MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT 64 // // Define the target number of per-thread multiplies before using another diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index 54fe7c4087..c090ab2a6c 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -7,6 +7,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/tensorprotoutils.h" +#include "core/mlas/inc/mlas.h" #include "core/graph/graph_utils.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/utils.h" @@ -49,18 +50,30 @@ bool ConvFusionDataTypeCheck(const Node& conv_node) { // Assess the support level for the other compatible EPs and if they also // only support float, remove the EP check altogether. 
const std::string_view node_ep = conv_node.GetExecutionProviderType(); - if (node_ep == kCudaExecutionProvider || node_ep == kCpuExecutionProvider) { + if (node_ep == kCudaExecutionProvider) { if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) { return false; } } + if (node_ep == kCpuExecutionProvider) { +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT) && + !HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) { + return false; + } +#else + if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) { + return false; + } +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED + } return true; } -class ConvActivation : public NodeSelector { +class ConvActivationSelector : public NodeSelector { public: - ConvActivation() = default; + ConvActivationSelector() = default; std::optional Select(const GraphViewer& graph_viewer, const Node& node) const override { const std::string_view node_ep = node.GetExecutionProviderType(); @@ -159,7 +172,7 @@ class ConvAddRelu : public NodeSelector { namespace actions { using NTO = NodesToOptimize; -class FuseConvActivation : public ReplaceWithNew { +class FuseConvActivationAction : public ReplaceWithNew { private: std::string OpType(const RuntimeState&) const override { return "FusedConv"; } @@ -245,9 +258,9 @@ class FuseConvAddRelu : public ReplaceWithNew { void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) { const auto name = "ConvAct"; - auto action = std::make_unique(); + auto action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) - auto selector = std::make_unique(); + auto selector = std::make_unique(); registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}}, std::move(selector), std::move(action)); #else diff --git a/onnxruntime/core/optimizer/conv_add_act_fusion.cc 
b/onnxruntime/core/optimizer/conv_add_act_fusion.cc index 5e34bce362..7c8bfeaec5 100644 --- a/onnxruntime/core/optimizer/conv_add_act_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_act_fusion.cc @@ -40,20 +40,28 @@ const Node* GetLoneConsumerNode(const GraphViewer& graph_viewer, const Node& nod return &*node.OutputNodesBegin(); } -class ConvAddActivation : public NodeSelector { +class ConvAddActivationSelector : public NodeSelector { public: - ConvAddActivation() = default; - + ConvAddActivationSelector() = default; std::optional Select(const GraphViewer& graph_viewer, const Node& node) const override { const std::string_view node_ep = node.GetExecutionProviderType(); - if (node_ep != kCpuExecutionProvider || !HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) { +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED + if (node_ep != kCpuExecutionProvider || + (!HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT) && + !HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16))) { + return std::nullopt; + } +#else + if (node_ep != kCpuExecutionProvider || + !HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) { return std::nullopt; } +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED // we can't assign `conv_node` as the producer-node, even it is, because we have to make sure // 1. Its type is 'conv', 2. it has to satisfy the other requirements,like shape, please refer to SelectConvProducer for more info const Node* conv_node = nullptr; const auto* add_node = GetLoneConsumerNode(graph_viewer, node); - if (!add_node) { + if (add_node == nullptr) { return std::nullopt; } // Let's support addition first, leave any-element-wise-op fusion in the future. 
@@ -64,13 +72,13 @@ class ConvAddActivation : public NodeSelector { if (graph_utils::IsSupportedOptypeVersionAndDomain(*add_node, "Add", {7, 13, 14})) { conv_node = SelectProducerConv(*add_node); } - if (!conv_node) { + if (conv_node == nullptr) { return std::nullopt; } // GetLoneConsumerNode will ensure outputedge_count is 1 const auto* act_node = GetLoneConsumerNode(graph_viewer, *add_node); // even the next node is not a activation node, it's also fine. - if (!act_node) { + if (act_node == nullptr) { // we can't fuse add-activation when add_node has multiple consumer nodes act_node = nullptr; } else if (SelectActivation(graph_viewer, *act_node)) { @@ -82,7 +90,7 @@ class ConvAddActivation : public NodeSelector { NodesToOptimizeIndicesBuilder builder{}; builder.target_node = conv_node->Index(); builder.output_nodes = {add_node->Index()}; - if (act_node) { + if (act_node != nullptr) { builder.output_nodes.push_back(act_node->Index()); } return builder.Build(); @@ -167,17 +175,27 @@ class ConvAddActivation : public NodeSelector { // Check if this is a single use convolution that hasn't already // been fused with another Add/Sum node. The Add/Sum can also only be // fused if the convolution isn't itself fused with an activation. - if ((inputs_node[n]->OpType() == "Conv") && (pre_input_defs_count < 4) && (producer_input_args_count.size() < 4) && - (graph_utils::GetNodeAttribute(*inputs_node[n], "activation") == nullptr) && (inputs_node[n]->GetOutputEdgesCount() == 1)) { - if (pre_input_defs_count < 3) { - // The optional bias parameter is empty so set to an empty string. + if ((inputs_node[n]->OpType() == "Conv") && (pre_input_defs_count < 4) && + (producer_input_args_count.size() < 4) && + (graph_utils::GetNodeAttribute(*inputs_node[n], "activation") == nullptr) && + (inputs_node[n]->GetOutputEdgesCount() == 1)) { + if (pre_input_defs_count < 3) { // The optional bias parameter is empty so set to an empty string. 
+ // TODO, add a new null arguments for bias + continue; + } + return inputs_node[n]; + } + if (inputs_node[n]->OpType() == "NhwcFusedConv" && (pre_input_defs_count < 4) && + (producer_input_args_count.size() < 5) && + (graph_utils::GetNodeAttribute(*inputs_node[n], "activation") == nullptr) && + (inputs_node[n]->GetOutputEdgesCount() == 1)) { + if (pre_input_defs_count < 3) { // The optional bias parameter is empty so set to an empty string. // TODO, add a new null arguments for bias continue; } return inputs_node[n]; } } - return nullptr; } }; @@ -187,9 +205,14 @@ class ConvAddActivation : public NodeSelector { namespace actions { using NTO = NodesToOptimize; -class FuseConvAddActivation : public ReplaceWithNew { +class FuseConvAddActivationAction : public ReplaceWithNew { + public: + FuseConvAddActivationAction() = default; + private: - std::string OpType(const RuntimeState&) const override { return "FusedConv"; } + std::string OpType(const RuntimeState& runtimeState) const override { + return (runtimeState.selected_nodes.Target().OpType() == "Conv") ? 
"FusedConv" : "NhwcFusedConv"; + } std::string Domain(const RuntimeState&) const override { return kMSDomain; } @@ -262,11 +285,14 @@ class FuseConvAddActivation : public ReplaceWithNew { } // namespace actions void RegisterConvAddActivationFusionRules(SelectorActionRegistry& registry) { - const auto name = "ConvAddAct"; - auto action = std::make_unique(); - auto selector = std::make_unique(); - registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}}, + auto action = std::make_unique(); + auto selector = std::make_unique(); + registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}}, std::move(selector), std::move(action)); + auto action_nhwc = std::make_unique(); + auto selector_nhwc = std::make_unique(); + registry.RegisterSelectorAndAction("NhwcFusedConvAct", {{"NhwcFusedConv", {1, 11}}}, + std::move(selector_nhwc), std::move(action_nhwc)); } SelectorActionRegistry CreateSelectorActionRegistry() { diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc index 540e0e92d3..e182b6c695 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc @@ -93,9 +93,10 @@ static Status MatchAndProcess( Status status = Status::OK(); do { - // TODO: for now this just needs to support ONNX ops. If we ever had a transformer that was going to - // target non-ONNX ops we'd need to rework a few things to include the op domain in the matches - if (node.Domain() != kOnnxDomain) { + // TODO: for now this just needs to support ONNX and Micrsoft Domain ops. 
+ // If we ever had a transformer that was going to target non-ONNX ops, + // we'd need to rework a few things to include the op domain in the matches + if (node.Domain() != kOnnxDomain && node.Domain() != kMSDomain) { break; } diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index b9617accd6..022ecc30f3 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -137,12 +137,20 @@ class WindowsThread : public EnvThread { static unsigned __stdcall ThreadMain(void* param) { std::unique_ptr p(static_cast(param)); - const ORTCHAR_T* name_prefix = - (p->name_prefix == nullptr || wcslen(p->name_prefix) == 0) ? L"onnxruntime" : p->name_prefix; - std::wostringstream oss; - oss << name_prefix << "-" << p->index; - // Ignore the error - (void)SetThreadDescription(GetCurrentThread(), oss.str().c_str()); + // Not all machines have kernel32.dll and/or SetThreadDescription (e.g. Azure App Service sandbox) + // so we need to ensure it's available before calling. + HMODULE kernelModule = GetModuleHandle(TEXT("kernel32.dll")); + if (kernelModule != nullptr) { + auto setThreadDescriptionFn = (SetThreadDescriptionFunc)GetProcAddress(kernelModule, "SetThreadDescription"); + if (setThreadDescriptionFn != nullptr) { + const ORTCHAR_T* name_prefix = (p->name_prefix == nullptr || wcslen(p->name_prefix) == 0) ? 
L"onnxruntime" + : p->name_prefix; + std::wostringstream oss; + oss << name_prefix << "-" << p->index; + // Ignore any errors + (void)(setThreadDescriptionFn)(GetCurrentThread(), oss.str().c_str()); + } + } unsigned ret = 0; ORT_TRY { diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index d06946579b..6db63752c5 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1511,7 +1511,7 @@ void CANNExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& OrtDevice CANNExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) { - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CANN_PINNED, default_device_.Id()); + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CANN_PINNED, 0 /*CPU device id always be 0*/); } return default_device_; } diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc index f2055aa8e8..d062d59de8 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.cc +++ b/onnxruntime/core/providers/coreml/builders/helper.cc @@ -46,6 +46,11 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const lo } bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) { + if (!input.Exists()) { + // optional input that is not provided + return true; + } + const auto& input_name = input.Name(); const auto* shape_proto = input.Shape(); // We do not support input with no shape @@ -86,12 +91,6 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe } #endif - const auto& graph_inputs = graph_viewer.GetInputs(); - if (std::any_of(graph_inputs.begin(), graph_inputs.end(), - [&](const NodeArg* input) { return !IsInputSupported(*input, "graph", 
logger); })) { - return supported_nodes; - } - for (const auto& node : graph_viewer.Nodes()) { const bool supported = IsNodeSupported(node, graph_viewer, logger); LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType() diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc index d4bf25ab64..e6867f1081 100644 --- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc +++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc @@ -32,7 +32,7 @@ using ConvPadVector = ConvAttributes::ConvPadVector; * 2. Activation * It takes an operator attribute 'activation', which supplies the activation info. * - * Add is performed AFTER activation. + * Add is performed BEFORE activation. * * The implementation supports both NCHW and NHWC. It runs faster with NHWC. * @@ -281,12 +281,10 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { if (Y->Shape().Size() == 0) { return Status::OK(); } - if (Sum) { - if (Sum->Shape() != Y->Shape()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Z shape does not match output shape.", - " Z: ", Sum->Shape().ToString().c_str(), - " Output: ", Y->Shape().ToString().c_str()); - } + if (Sum && Sum->Shape() != Y->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Z shape does not match output shape.", + " Z: ", Sum->Shape().ToString().c_str(), + " Output: ", Y->Shape().ToString().c_str()); } const int64_t input_image_size = input_shape.Size(); @@ -338,7 +336,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { const auto* Xdata = X->Data(); const auto* Bdata = B != nullptr ? B->Data() : nullptr; auto* Ydata = Y->MutableData(); - const auto* SumData = Sum != nullptr ? Sum->Data() : nullptr; + const auto* sum_data = Sum != nullptr ? 
Sum->Data() : nullptr; BufferUniquePtr transpose_input_buffer; BufferUniquePtr transpose_output_buffer; @@ -409,7 +407,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { for (int64_t image_id = 0; image_id < N; ++image_id) { const auto* input_data = Xdata; auto* output_data = Ydata; - const auto* add_src = SumData; + const auto* add_src = sum_data; if (!channels_last_) { // Transpose the input from channels first (CHW) to channels last (HWC). @@ -478,7 +476,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { static_cast(M), static_cast(output_count), static_cast(kernel_size), - &act); + (!channels_last_ && sum_data) ? nullptr : &act); } else { for (int64_t group_id = 0; group_id < group_count; ++group_id) { // Prepare the im2col transformation or use the input buffer directly for @@ -554,7 +552,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { gemm_params.C = worker_output + group_id * group_output_channels; gemm_params.ldc = static_cast(M); gemm_params.Bias = Bdata; - gemm_params.OutputProcessor = &act; // process fused activation and add + gemm_params.OutputProcessor = (!channels_last_ && sum_data) ? 
nullptr : &act; // process fused activation and add MlasHalfGemmBatch( static_cast(output_count), @@ -574,10 +572,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { Ydata, static_cast(output_image_size), static_cast(M)); - if (SumData != nullptr) { - MLAS_ACTIVATION activation; - activation.ActivationKind = MlasIdentityActivation; - MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(activation, SumData); + if (sum_data != nullptr) { + MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(activation_, sum_data); proc.Process(Ydata, 0, 0, static_cast(M), static_cast(output_image_size), static_cast(output_image_size)); @@ -586,8 +582,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const { Xdata += X_offset; Ydata += Y_offset; - if (SumData != nullptr) { - SumData += Y_offset; + if (sum_data != nullptr) { + sum_data += Y_offset; } } diff --git a/onnxruntime/core/providers/cpu/math/matmul_helper.h b/onnxruntime/core/providers/cpu/math/matmul_helper.h index 7c86f11b97..d7275ee324 100644 --- a/onnxruntime/core/providers/cpu/math/matmul_helper.h +++ b/onnxruntime/core/providers/cpu/math/matmul_helper.h @@ -371,9 +371,24 @@ class MatMulComputeHelper { return right_zp_offsets_; } + static bool IsAligned(const std::vector& offsets) { + constexpr size_t alignment = 16; + const auto len = offsets.size(); + for (size_t i = 0; i < len; i++) { + if ((offsets[i] % alignment) != 0) { + return false; + } + } + return true; + } + + bool IsBatchedGemmAligned() const { + return IsAligned(left_offsets_) && IsAligned(right_offsets_) && IsAligned(output_offsets_); + } + template static void OffsetToArrays(T* p, const std::vector& offsets, gsl::span arrays) { - auto len = offsets.size(); + const auto len = offsets.size(); ORT_ENFORCE(arrays.size() == len); for (size_t i = 0; i < len; i++) { arrays[i] = p + offsets[i]; @@ -382,7 +397,7 @@ class MatMulComputeHelper { template static void OffsetToArrays(const T* p, const std::vector& offsets, gsl::span arrays) { - auto len = 
offsets.size(); + const auto len = offsets.size(); ORT_ENFORCE(arrays.size() == len); for (size_t i = 0; i < len; i++) { arrays[i] = p + offsets[i]; diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h index a4d67ec63f..984f4795ce 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h @@ -44,14 +44,16 @@ struct ConvTransposeAttributes : public ConvAttributes { }; Status PrepareForCompute(OpKernelContext* context, bool has_bias, Prepare& p, - bool dynamic_padding = false, const TensorShape* filter_shape = nullptr) const { + bool dynamic_padding = false, const TensorShape* filter_shape = nullptr, bool is_quant = false) const { const Tensor* X = context->Input(0); - const Tensor* F = (filter_shape != nullptr) ? nullptr : context->Input(1); + const Tensor* F = (filter_shape != nullptr) ? nullptr : context->Input(is_quant ? 3 : 1); const TensorShape& F_Shape = (filter_shape != nullptr) ? *filter_shape : F->Shape(); + if (dynamic_padding && is_quant) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "dynamic padding is not supported for quantized conv tranpose."); + } const Tensor* Pads = dynamic_padding ? context->Input(2) : nullptr; - const Tensor* B = has_bias ? (dynamic_padding ? context->Input(3) : context->Input(2)) : nullptr; + const Tensor* B = has_bias ? (dynamic_padding ? context->Input(is_quant ? 9 : 3) : context->Input(is_quant ? 
8 : 2)) : nullptr; TensorShape input_shape = X->Shape().Slice(2); - const int64_t num_input_channels = X->Shape()[1]; const int64_t N = X->Shape()[0]; const int64_t num_output_channels_multiplier = F_Shape[1]; diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconvtranspose.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconvtranspose.cc new file mode 100644 index 0000000000..513f72030b --- /dev/null +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconvtranspose.cc @@ -0,0 +1,254 @@ +/** +* Copyright (c) 2014-present, Quadric, Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +/* Modifications Copyright (c) Microsoft. 
*/ + +#include "core/mlas/inc/mlas.h" +#include "core/common/safeint.h" +#include "core/util/math.h" +#include "core/util/math_cpuonly.h" + +#include "core/common/inlined_containers_fwd.h" +#include "core/framework/transpose_helper.h" +#include "core/providers/utils.h" +#include "core/framework/tensorprotoutils.h" + +#include "core/providers/cpu/nn/conv_transpose_attributes.h" + +namespace onnxruntime { + +template +class QLinearConvTranspose : public OpKernel { + public: + explicit QLinearConvTranspose(const OpKernelInfo& info) : OpKernel(info), conv_transpose_attrs_(info) { + } + + Status Compute(OpKernelContext* context) const override; + + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, + /*out*/ bool& used_shared_buffers) override; + + protected: + Status DoConvTranspose(OpKernelContext* context) const; + + private: + static float ComputeOutputScale(OpKernelContext* context) { + const Tensor* X_scale = context->Input(InputTensors::IN_X_SCALE); + const Tensor* W_scale = context->Input(InputTensors::IN_W_SCALE); + const Tensor* Y_scale = context->Input(InputTensors::IN_Y_SCALE); + ORT_ENFORCE(IsScalarOr1ElementVector(X_scale), + "QLinearConv : input scale must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(Y_scale), + "QLinearConv : result scale must be a scalar or 1D tensor of size 1"); + ORT_ENFORCE(IsScalarOr1ElementVector(W_scale), + "QLinearConv : filter scale must be a scalar or 1D tensor of size 1"); + + auto X_scale_value = *(X_scale->Data()); + auto Y_scale_value = *(Y_scale->Data()); + auto W_scale_value = *(W_scale->Data()); + + return X_scale_value * W_scale_value / Y_scale_value; + } + + enum InputTensors : int { + IN_X = 0, + IN_X_SCALE = 1, + IN_X_ZERO_POINT = 2, + IN_W = 3, + IN_W_SCALE = 4, + IN_W_ZERO_POINT = 5, + IN_Y_SCALE 
= 6, + IN_Y_ZERO_POINT = 7, + IN_BIAS = 8 + + }; + enum OutputTensors : int { + OUT_Y = 0 + }; + + ConvTransposeAttributes conv_transpose_attrs_; + + // for pre-packing usage + TensorShape filter_shape_; + BufferUniquePtr transposed_filter_; +}; + +template +Status QLinearConvTranspose::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* /*prepacked_weights*/ +) { + is_packed = false; + return Status::OK(); +} + +template +Status QLinearConvTranspose::UseSharedPrePackedBuffers(std::vector& /*prepacked_buffers*/, + int /*input_idx*/, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + return Status::OK(); +} + +template +Status QLinearConvTranspose::Compute(OpKernelContext* context) const { + return QLinearConvTranspose::DoConvTranspose(context); +} + +template +Status QLinearConvTranspose::DoConvTranspose(OpKernelContext* context) const { + typedef int32_t ActType; + size_t num_inputs = OpKernel::Node().InputDefs().size(); + ConvTransposeAttributes::Prepare p; + bool has_bias = num_inputs == 9; + ORT_RETURN_IF_ERROR(conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, false, nullptr, true)); + + const int64_t input_image_size = p.input_shape.Size(); + + // Bail out early if one of the dimensions is zero. 
+ if (p.Y->Shape().Size() == 0) { + return Status::OK(); + } + + // Quantization parameters, only support symmetric quant for now + float scale_value = ComputeOutputScale(context); + + const int64_t X_offset = p.num_input_channels / conv_transpose_attrs_.group * input_image_size; + const int64_t Y_offset = p.Y->Shape().Size() / p.Y->Shape()[0] / conv_transpose_attrs_.group; + const int64_t W_offset = p.F->Shape().Size() / conv_transpose_attrs_.group; + const int64_t kernel_size = TensorShape(p.kernel_shape).Size(); + const int64_t kernel_dim = p.num_output_channels / conv_transpose_attrs_.group * kernel_size; + const int64_t output_size = (p.Y->Shape().Slice(2)).Size(); + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); + + const int64_t col_buffer_size = kernel_dim * p.input_shape.Size(); + auto col_data = alloc->Alloc(SafeInt(sizeof(ActType)) * col_buffer_size); + BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc)); + + // Pre-transpose the weight matrix because MatMul does not take transpose params + const int64_t trans_filt_size = p.num_input_channels / conv_transpose_attrs_.group * kernel_dim; + auto trans_filt_data = alloc->Alloc(SafeInt(sizeof(T)) * trans_filt_size); + BufferUniquePtr trans_filt(trans_filt_data, BufferDeleter(alloc)); + ActType* col_buffer_data = static_cast(col_buffer.get()); + + const T* Xdata = p.X->Data(); + const T* filter_data = p.F->Data(); + T* Ydata = p.Y->MutableData(); + TensorShape output_shape = p.Y->Shape().Slice(2); + MlasTranspose( + filter_data, + static_cast(trans_filt.get()), + p.num_input_channels / conv_transpose_attrs_.group, + kernel_dim); + + //TODO: add support for assymmetric quantization and using int8 MlassGemm. This will require offseting the + //input data to an uint8 range and passing a zero point parameter. 
+ //For now compute the GEMM in int32, as MlassGemm requires uint8 input and MlasSymmQgemmBatch is ARM only + auto inp_i32_buffer = alloc->Alloc(SafeInt(sizeof(ActType)) * p.X->Shape().Size()); + BufferUniquePtr inp_i32(inp_i32_buffer, BufferDeleter(alloc)); + ActType* inp_i32_data = static_cast(inp_i32.get()); + for (std::int64_t i = 0; i < p.X->Shape().Size(); i++) { + inp_i32_data[i] = Xdata[i]; + } + auto wt_i32_buffer = alloc->Alloc(SafeInt(sizeof(ActType)) * p.F->Shape().Size()); + BufferUniquePtr wt_i32(wt_i32_buffer, BufferDeleter(alloc)); + ActType* wt_i32_data = static_cast(wt_i32.get()); + for (std::int64_t i = 0; i < p.F->Shape().Size(); i++) { + wt_i32_data[i] = static_cast(trans_filt.get())[i]; + } + auto out_i32_buffer = alloc->Alloc(SafeInt(sizeof(ActType)) * p.Y->Shape().Size()); + BufferUniquePtr out_i32(out_i32_buffer, BufferDeleter(std::move(alloc))); + ActType* out_i32_data = static_cast(out_i32.get()); + + for (auto image_id = 0; image_id < p.N; ++image_id) { + for (int group_id = 0; group_id < conv_transpose_attrs_.group; ++group_id) { + math::MatMul(kernel_dim, input_image_size, p.num_input_channels / conv_transpose_attrs_.group, + wt_i32_data + group_id * W_offset, + inp_i32_data + group_id * X_offset, + col_buffer_data, + nullptr); + + math::Col2im( + reinterpret_cast(col_buffer_data), + p.num_output_channels / conv_transpose_attrs_.group, + p.Y->Shape()[2], + p.Y->Shape()[3], + p.kernel_shape[0], + p.kernel_shape[1], + p.dilations[0], + p.dilations[1], + p.pads[0], + p.pads[1], + p.pads[2], + p.pads[3], + p.strides[0], + p.strides[1], + reinterpret_cast(out_i32_data + group_id * Y_offset), + &CPUMathUtil::Instance()); + } + + if (p.B != nullptr) { + auto out_i32_matrix = EigenMatrixMap(out_i32_data, output_size, p.num_output_channels); + auto Bvec = ConstEigenVectorMap(p.B->Data(), p.num_output_channels); + out_i32_matrix.rowwise() += Bvec.transpose(); + } + + MlasRequantizeOutput(out_i32_data, + p.Y->Shape()[2] * p.Y->Shape()[3], + 
Ydata, + p.Y->Shape()[2] * p.Y->Shape()[3], + nullptr, + &scale_value, + false, + (T)0, + 0, + 0, + p.num_output_channels, + p.Y->Shape()[2] * p.Y->Shape()[3]); + Xdata += X_offset * conv_transpose_attrs_.group; + Ydata += Y_offset * conv_transpose_attrs_.group; + } + + return Status::OK(); +} + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib { + +// Register the operator with int8 inputs and weights +ONNX_OPERATOR_KERNEL_EX( + QLinearConvTranspose, + kMSDomain, + 1, + //int8_t, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + QLinearConvTranspose); + +} // namespace contrib +#endif + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 0d3c54d676..2cd0467c77 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2532,9 +2532,8 @@ void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& } OrtDevice CUDAExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { - if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) { - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, default_device_.Id()); - } + if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); + if (mem_type == OrtMemTypeCPUOutput) return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device id always be 0*/); return default_device_; } diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc index 9d7e673109..5a813a98a6 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.cc +++ b/onnxruntime/core/providers/cuda/math/matmul.cc @@ -180,6 +180,14 @@ Status 
MatMul::ComputeInternal(OpKernelContext* ctx) const { ORT_RETURN_IF_ERROR(right_arrays.CopyToGpu(ctx->GetComputeStream())); ORT_RETURN_IF_ERROR(output_arrays.CopyToGpu(ctx->GetComputeStream())); + // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. + // It requires Ampere or newer GPU, and pointers of matrics shall be aligned (ideal alignment is 16-byte). + // Assume that start memory of input/output tensor is aligned, we only check offsets of sub-matrix per batch here. + cublasMath_t mode = (std::is_same::value && device_prop.major >= 8 && helper.IsBatchedGemmAligned()) + ? CUBLAS_TF32_TENSOR_OP_MATH + : CUBLAS_DEFAULT_MATH; + CublasMathModeSetter math_mode_setter(device_prop, GetCublasHandle(ctx), mode); + // note that onnxruntime OrtValue is row major, while cublas is column major, // so swap left/right operands CUBLAS_RETURN_IF_ERROR(cublasGemmBatchedHelper( diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h index 8bf429949e..8424d14d39 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h @@ -185,13 +185,7 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const float* beta, float* Carray[], int ldc, int batch_count, - const cudaDeviceProp& prop) { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, CUBLAS_TF32_TENSOR_OP_MATH); -#else - ORT_UNUSED_PARAMETER(prop); -#endif - + const cudaDeviceProp&) { return cublasSgemmBatched(handle, transa, transb, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index f5abeb5d7b..aae73dca46 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -244,17 +244,12 @@ namespace Dml void* CPUAllocator::Alloc(size_t size) { - if (size <= 0) - { - return nullptr; - } - void* p = malloc(size); - return p; + return onnxruntime::AllocatorDefaultAlloc(size); } void CPUAllocator::Free(void* p) { - free(p); + return onnxruntime::AllocatorDefaultFree(p); } } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 55322bea6b..2f1890bb0b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -210,6 +210,8 @@ namespace Dml m_cpuOutputAllocator = std::make_shared(OrtMemType::OrtMemTypeCPUOutput); CreateDmlKernelRegistry(&m_kernelRegistry, &m_internalRegInfoMap); + + m_lastUploadFlushTime = std::chrono::steady_clock::now(); } HRESULT __stdcall ExecutionProviderImpl::GetD3DDevice(_COM_Outptr_ ID3D12Device** d3dDevice) const noexcept @@ -447,6 +449,7 @@ namespace Dml const auto dstState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state m_uploadHeap->BeginUploadToGpu(dstData, dstOffset, dstState, AsByteSpan(srcData, dataSizeInBytes)); + FlushUploadsIfReady(); } else if (!src->IsCpuData() && dst->IsCpuData()) { @@ -566,12 +569,23 @@ namespace Dml assert(!m_closed); m_uploadHeap->BeginUploadToGpu(dstData, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, AsByteSpan(srcData, static_cast(srcDataSize))); + FlushUploadsIfReady(); return S_OK; } ORT_CATCH_RETURN } + void ExecutionProviderImpl::FlushUploadsIfReady() const + { + // Periodically flush uploads to make sure the GPU is not idle for too long + if (std::chrono::steady_clock::now() - m_lastUploadFlushTime > m_batchFlushInterval) + { + Flush(); + m_lastUploadFlushTime = std::chrono::steady_clock::now(); 
+ } + } + uint32_t ExecutionProviderImpl::GetSupportedDeviceDataTypeMask() const { // The DML provider registers all supported kernels up-front regardless of actual device capability, @@ -661,9 +675,10 @@ namespace Dml bool IsCustomOpShader(const onnxruntime::Node& node) { - auto custom_ops = std::array{ + auto custom_ops = std::array{ "DFT", - "STFT" + "STFT", + "GridSample" }; for (auto& custom_op : custom_ops) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index b9ac772095..04f3420f96 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -180,6 +180,8 @@ namespace Dml uint32_t supportedDeviceDataTypeMask // Each bit corresponds to each DML_TENSOR_DATA_TYPE. ) const; + void FlushUploadsIfReady() const; + ComPtr m_d3d12Device; ComPtr m_dmlDevice; bool m_isMcdmDevice = false; @@ -195,6 +197,8 @@ namespace Dml std::shared_ptr m_internalRegInfoMap; mutable uint64_t m_partitionKernelPrefixVal = 0; bool m_closed = false; + mutable std::chrono::time_point m_lastUploadFlushTime; + static constexpr std::chrono::milliseconds m_batchFlushInterval = std::chrono::milliseconds(10); }; class DataTransfer : public onnxruntime::IDataTransfer diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h index 76ca37bd05..9a1c23093f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h @@ -183,13 +183,13 @@ class StackAllocator static T RoundUpToMultiple(T value, T multiple) { static_assert(std::is_integral_v); - + T remainder = value % multiple; if 
(remainder != 0) { value += multiple - remainder; } - + return value; } @@ -231,4 +231,4 @@ class StackAllocator // allocated memory if the fixed stack array is exhausted. FixedBucket m_fixed; std::deque m_dynamic; -}; \ No newline at end of file +}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index 8b5cf36937..c75b662af7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -1155,6 +1155,11 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_GELU; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION; +}; template struct OperatorTypeTraits @@ -2139,14 +2144,20 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_GELU> using DescType = DML_ACTIVATION_GELU_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION> +{ + using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC; +}; + // Calls a visitor functor, supplying an empty operator desc corresponding to the given DML_OPERATOR_TYPE as // the first argument. -// +// // For example: // Visit(DML_OPERATOR_ELEMENT_WISE_IDENTITY, [](auto tag) { // using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs // }); -// +// #pragma warning(push) #pragma warning(disable:4702) template @@ -2432,6 +2443,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... 
args return std::invoke(std::forward(visitor), DML_RESAMPLE_GRAD1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_DIAGONAL_MATRIX1: return std::invoke(std::forward(visitor), DML_DIAGONAL_MATRIX1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MULTIHEAD_ATTENTION: + return std::invoke(std::forward(visitor), DML_MULTIHEAD_ATTENTION_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_ELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_CELU: @@ -2633,6 +2646,7 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2"; case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1"; case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1"; + case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION"; default: assert(false); return ""; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 42de619a87..1ebd52d4ed 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -2302,6 +2302,35 @@ constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA{ DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA_FIELDS[18] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "QueryTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "KeyTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ValueTensor", true }, + 
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "StackedQueryKeyTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "StackedKeyValueTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "StackedQueryKeyValueTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MaskTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "RelativePositionBiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastKeyTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastValueTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputPresentKeyTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputPresentValueTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Scale", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "MaskFilterValue", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "HeadCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MaskType", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA { + "DML_OPERATOR_MULTIHEAD_ATTENTION", + DML_OPERATOR_MULTIHEAD_ATTENTION, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 18, + 
DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLX.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLX.h index f7feb5ff21..7e050eef50 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLX.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLX.h @@ -202,7 +202,7 @@ namespace dml }; } -#if DMLX_USE_ABSEIL +#if DMLX_USE_ABSEIL template using Optional = absl::optional; @@ -231,7 +231,7 @@ namespace dml #elif DMLX_USE_GSL template using Span = gsl::span; - #else + #else template using Span = dml::detail::span; #endif @@ -245,11 +245,11 @@ namespace dml #define DMLX_THROW(_hr) THROW_HR(_hr) #else #define DMLX_THROW_IF_FAILED(_hr) if (FAILED(_hr)) { throw std::runtime_error(#_hr); } - #define DMLX_THROW(_hr) throw std::runtime_error(#_hr); + #define DMLX_THROW(_hr) throw std::runtime_error(#_hr); #endif #else #define DMLX_THROW_IF_FAILED(_hr) if (FAILED(_hr)) { std::abort(); } - #define DMLX_THROW(_hr) { std::abort(); } + #define DMLX_THROW(_hr) { std::abort(); } #endif class Graph; @@ -307,7 +307,7 @@ namespace dml // (0, 2, ..., n, 1). This is often referred to as "NHWC" or "interleaved channel" layout. This is useful, // for example, when applied to 2D Convolution to produce outputs in an NHWC layout (as opposed to NCHW, which // is the DirectML default for 2D Convolution). 
- // + // // Examples of the transposes produced by this policy: // NCW -> NWC // NCHW -> NHWC @@ -713,7 +713,7 @@ namespace dml // Represents an activation to be fused with an existing operator. The meaning of param1 and param2 depend on the // activation to be fused. - // + // // For HARD_SIGMOID, LINEAR, PARAMETRIC_SOFTPLUS, and SCALED_TANH: param1 = Alpha and param2 = Beta // For ELU, LEAKY_RELU, THRESHOLDED_RELU, and CELU: param1 = Alpha. param2 is unused. // For SCALED_ELU, param1 = Alpha and param2 = Gamma. @@ -1858,13 +1858,13 @@ namespace dml } // Helper for setting parameters for the Convolution operator. Sample usage: - // + // // auto conv = dml::ConvolutionBuilder(...) // .StartPadding(...) // .EndPadding(...) // .Strides(...) // .Build(); - // + // // Parameters left unspecified will be defaulted with the same values as dml::Convolution(). class ConvolutionBuilder { @@ -2114,9 +2114,9 @@ namespace dml return output; } - // + // // TODO: LpPooling - // + // // --------------------------------------------------------------------------------------------------------------- @@ -2203,13 +2203,13 @@ namespace dml } // Helper for setting parameters for the MaxPooling operator. Sample usage: - // + // // auto [out, outIndices] = dml::MaxPoolingBuilder(...) // .StartPadding(...) // .EndPadding(...) // .OutputIndices(...) // .Build(); - // + // // Parameters left unspecified will be defaulted with the same values as dml::MaxPooling(). 
class MaxPoolingBuilder { @@ -2251,13 +2251,13 @@ namespace dml // --------------------------------------------------------------------------------------------------------------- - // + // // TODO: MaxUnpooling - // + // - // + // // TODO: ROIPooling - // + // inline Expression Slice( Expression input, @@ -2683,7 +2683,7 @@ namespace dml { detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); TensorDesc inputTensor = input.Impl()->GetOutputDesc(); - + assert(inputTensor.sizes.size() == 4); dml::TensorDesc::Dimensions outputSizes = { @@ -2691,7 +2691,7 @@ namespace dml inputTensor.sizes[1] * blockSize * blockSize, inputTensor.sizes[2] / blockSize, inputTensor.sizes[3] / blockSize - }; + }; TensorDesc outputTensor(inputTensor.dataType, outputSizes, builder->GetTensorPolicy()); @@ -2715,7 +2715,7 @@ namespace dml { detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder(); TensorDesc inputTensor = input.Impl()->GetOutputDesc(); - + assert(inputTensor.sizes.size() == 4); dml::TensorDesc::Dimensions outputSizes = { @@ -2771,7 +2771,7 @@ namespace dml struct TopKOutputs { Expression value; - Expression index; + Expression index; }; inline TopKOutputs TopK(Expression input, uint32_t axis, uint32_t k, DML_AXIS_DIRECTION axisDirection) @@ -2909,14 +2909,14 @@ namespace dml desc.VarianceTensor = varianceTensor.AsPtr(); desc.ScaleTensor = scaleTensor.AsPtr(); desc.Epsilon = epsilon; - + desc.OutputGradientTensor = outputGradientTensor.AsPtr(); desc.OutputScaleGradientTensor = outputScaleGradientTensor.AsPtr(); desc.OutputBiasGradientTensor = outputBiasGradientTensor.AsPtr(); - + dml::detail::NodeOutput* const inputs[] = { input.Impl(), inputGradient.Impl(), mean.Impl(), variance.Impl(), scale.Impl() }; dml::detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_BATCH_NORMALIZATION_GRAD, &desc, inputs); - + BatchNormalizationGradOutputs outputValues; outputValues.gradient = builder->CreateNodeOutput(node, 0, *desc.OutputGradientTensor); 
outputValues.scaleGradient = builder->CreateNodeOutput(node, 1, *desc.OutputScaleGradientTensor); @@ -2932,7 +2932,7 @@ namespace dml { Expression output; Expression mean; - Expression variance; + Expression variance; }; inline BatchNormalizationTrainingOutputs BatchNormalizationTraining( @@ -3005,14 +3005,14 @@ namespace dml desc.VarianceTensor = varianceTensor.AsPtr(); desc.ScaleTensor = scaleTensor.AsPtr(); desc.Epsilon = epsilon; - + desc.OutputGradientTensor = outputGradientTensor.AsPtr(); desc.OutputScaleGradientTensor = outputScaleGradientTensor.AsPtr(); desc.OutputBiasGradientTensor = outputBiasGradientTensor.AsPtr(); - + dml::detail::NodeOutput* const inputs[] = { input.Impl(), inputGradient.Impl(), mean.Impl(), variance.Impl(), scale.Impl() }; dml::detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD, &desc, inputs); - + BatchNormalizationGradOutputs outputValues; outputValues.gradient = builder->CreateNodeOutput(node, 0, *desc.OutputGradientTensor); outputValues.scaleGradient = builder->CreateNodeOutput(node, 1, *desc.OutputScaleGradientTensor); @@ -3099,17 +3099,17 @@ namespace dml return output; } - // + // // TODO: LpNormalization - // + // - // + // // TODO: RNN - // + // - // + // // TODO: LSTM - // + // enum class GRUOutputOptions { @@ -3121,7 +3121,7 @@ namespace dml struct GRUOutputs { Expression sequence; - Expression single; + Expression single; }; inline GRUOutputs GRU( @@ -3230,7 +3230,7 @@ namespace dml return { outputSequenceExpr, outputSingleExpr }; } - // + // // TODO: DiagonalMatrix // @@ -3442,33 +3442,33 @@ namespace dml return output; } - // + // // TODO: MatrixMultiplyInteger - // + // - // + // // TODO: QuantizedLinearMatrixMultiply - // + // - // + // // TODO: ConvolutionInteger - // + // - // + // // TODO: QuantizedLinearConvolution - // + // - // + // // TODO: ReluGrad - // + // - // + // // TODO: AveragePoolingGrad - // + // - // + // // TODO: MaxPoolingGrad - // + // struct 
RandomGeneratorOutputs { @@ -3496,7 +3496,7 @@ namespace dml // Input and output state have the same TensorDesc. desc.OutputStateTensor = inputStateTensor.AsPtr(); } - + RandomGeneratorOutputs out; detail::NodeOutput* const inputs[] = { inputState.Impl() }; @@ -3537,7 +3537,7 @@ namespace dml desc.InputTensor = inputTensor.AsPtr(); desc.OutputCountTensor = outputCountTensor.AsPtr(); desc.OutputCoordinatesTensor = outputCoordinatesTensor.AsPtr(); - + NonZeroCoordinatesOutputs output; detail::NodeOutput* const inputs[] = { input.Impl() }; @@ -3640,17 +3640,17 @@ namespace dml return output; } - // + // // TODO: AdamOptimizer - // + // - // + // // TODO: Argmin - // + // - // + // // TODO: Argmax - // + // #if DML_TARGET_VERSION >= 0x4000 @@ -3694,7 +3694,7 @@ namespace dml desc.ROITensor = roiTensor.AsPtr(); desc.BatchIndicesTensor = batchIndicesTensor.AsPtr(); desc.OutputTensor = outputTensor.AsPtr(); - desc.ReductionFunction = reductionFunction; + desc.ReductionFunction = reductionFunction; desc.InterpolationMode = interpolationMode; desc.SpatialScaleX = spatialScaleX; desc.SpatialScaleY = spatialScaleY; @@ -3763,7 +3763,7 @@ namespace dml outputGradientTensor = TensorDesc(inputGradientTensor.dataType, outputGradientSizes, builder->GetTensorPolicy()); } - + TensorDesc outputROIGradientTensor = computeOutputROIGradient ? TensorDesc(roiTensor.dataType, roiTensor.sizes, builder->GetTensorPolicy()) : TensorDesc(); assert(!computeOutputROIGradient || outputROIGradientTensor.sizes == roiTensor.sizes); @@ -3774,7 +3774,7 @@ namespace dml desc.BatchIndicesTensor = batchIndicesTensor.AsPtr(); desc.OutputGradientTensor = computeOutputGradient ? outputGradientTensor.AsPtr() : nullptr; desc.OutputROIGradientTensor = computeOutputROIGradient ? 
outputROIGradientTensor.AsPtr() : nullptr; - desc.ReductionFunction = reductionFunction; + desc.ReductionFunction = reductionFunction; desc.InterpolationMode = interpolationMode; desc.SpatialScaleX = spatialScaleX; desc.SpatialScaleY = spatialScaleY; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index aaf02ca146..833871de0b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -1418,6 +1418,29 @@ inline std::vector GetFields(const DML_DIAGONAL_MATRIX1_OPERATOR_ OperatorField(&DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.DiagonalFillEnd))), }; } +inline std::vector GetFields(const DML_MULTIHEAD_ATTENTION_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.QueryTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.KeyTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ValueTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.StackedQueryKeyTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.StackedKeyValueTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StackedQueryKeyValueTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[7], 
ToOperatorFieldType(static_cast(desc.MaskTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.RelativePositionBiasTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.PastKeyTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.PastValueTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.OutputPresentKeyTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast(desc.OutputPresentValueTensor))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[14], ToOperatorFieldType(static_cast(desc.Scale))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[15], ToOperatorFieldType(static_cast(desc.MaskFilterValue))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[16], ToOperatorFieldType(static_cast(desc.HeadCount))), + OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[17], ToOperatorFieldType(static_cast(desc.MaskType))), + }; +} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1753,6 +1776,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_RESAMPLE2: return DML_RESAMPLE2_OPERATOR_SCHEMA; case DML_OPERATOR_RESAMPLE_GRAD1: return DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA; case DML_OPERATOR_DIAGONAL_MATRIX1: return DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA; + case DML_OPERATOR_MULTIHEAD_ATTENTION: return DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA; case 
DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; @@ -2346,6 +2370,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MULTIHEAD_ATTENTION: + return AbstractOperatorDesc( + &DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h new file mode 100644 index 0000000000..c63863853f --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h @@ -0,0 +1,988 @@ +#pragma once + +#include "../../../OperatorAuthorHelper/OperatorHelper.h" +#include "../MLOperatorAuthorImpl.h" + +#include "../External/D3DX12/d3dx12.h" +#include + +// NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback +// should be removed from IsCustomOpShader(...) 
in +// onnxruntime\core\providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp + +// The shader headers are produced using "GeneratedShaders/GenerateShaders.bat" +namespace GridSample_uint16_float +{ + #include "GeneratedShaders/grid_sample_uint16_float.h" +} + +namespace GridSample_uint_float +{ + #include "GeneratedShaders/grid_sample_uint_float.h" +} + +namespace GridSample_uint64_float +{ + #include "GeneratedShaders/grid_sample_uint64_float.h" +} + +namespace GridSample_int16_float +{ + #include "GeneratedShaders/grid_sample_int16_float.h" +} + +namespace GridSample_int_float +{ + #include "GeneratedShaders/grid_sample_int_float.h" +} + +namespace GridSample_int64_float +{ + #include "GeneratedShaders/grid_sample_int64_float.h" +} + +namespace GridSample_fp16_float +{ + #include "GeneratedShaders/grid_sample_fp16_float.h" +} + +namespace GridSample_float_float +{ + #include "GeneratedShaders/grid_sample_float_float.h" +} + +namespace GridSample_double_float +{ + #include "GeneratedShaders/grid_sample_double_float.h" +} + +namespace GridSample_bool_float +{ + #include "GeneratedShaders/grid_sample_bool_float.h" +} + +namespace GridSample_uint16_fp16 +{ + #include "GeneratedShaders/grid_sample_uint16_fp16.h" +} + +namespace GridSample_uint_fp16 +{ + #include "GeneratedShaders/grid_sample_uint_fp16.h" +} + +namespace GridSample_uint64_fp16 +{ + #include "GeneratedShaders/grid_sample_uint64_fp16.h" +} + +namespace GridSample_int16_fp16 +{ + #include "GeneratedShaders/grid_sample_int16_fp16.h" +} + +namespace GridSample_int_fp16 +{ + #include "GeneratedShaders/grid_sample_int_fp16.h" +} + +namespace GridSample_int64_fp16 +{ + #include "GeneratedShaders/grid_sample_int64_fp16.h" +} + +namespace GridSample_fp16_fp16 +{ + #include "GeneratedShaders/grid_sample_fp16_fp16.h" +} + +namespace GridSample_float_fp16 +{ + #include "GeneratedShaders/grid_sample_float_fp16.h" +} + +namespace GridSample_double_fp16 +{ + #include 
"GeneratedShaders/grid_sample_double_fp16.h" +} + +namespace GridSample_bool_fp16 +{ + #include "GeneratedShaders/grid_sample_bool_fp16.h" +} + +namespace GridSample_uint16_double +{ + #include "GeneratedShaders/grid_sample_uint16_double.h" +} + +namespace GridSample_uint_double +{ + #include "GeneratedShaders/grid_sample_uint_double.h" +} + +namespace GridSample_uint64_double +{ + #include "GeneratedShaders/grid_sample_uint64_double.h" +} + +namespace GridSample_int16_double +{ + #include "GeneratedShaders/grid_sample_int16_double.h" +} + +namespace GridSample_int_double +{ + #include "GeneratedShaders/grid_sample_int_double.h" +} + +namespace GridSample_int64_double +{ + #include "GeneratedShaders/grid_sample_int64_double.h" +} + +namespace GridSample_fp16_double +{ + #include "GeneratedShaders/grid_sample_fp16_double.h" +} + +namespace GridSample_float_double +{ + #include "GeneratedShaders/grid_sample_float_double.h" +} + +namespace GridSample_double_double +{ + #include "GeneratedShaders/grid_sample_double_double.h" +} + +namespace GridSample_bool_double +{ + #include "GeneratedShaders/grid_sample_bool_double.h" +} + + +#include +#include + +#include + +using namespace Microsoft::WRL; + +enum DmlGridSampleKernelInputIndex : uint32_t +{ + Input, + Grid, +}; + +enum DmlGridSampleMode : uint32_t +{ + Bilinear, + Nearest, + Bicubic, +}; + +enum DmlGridSamplePaddingMode : uint32_t +{ + Zeros, + Border, + Reflection +}; + +// Helper to derive dimensions and attributes from either the GridSample shape inferrer or the GridSample kernel constructor. 
+struct DmlGridSampleParameters +{ + uint32_t batchSize = 0; + uint32_t channelSize = 0; + uint32_t height = 0; + uint32_t width = 0; + int64_t alignCorners = 0; + DmlGridSampleMode mode = DmlGridSampleMode::Bilinear; + DmlGridSamplePaddingMode paddingMode = DmlGridSamplePaddingMode::Zeros; + + DML_TENSOR_DATA_TYPE dataType = DML_TENSOR_DATA_TYPE_UNKNOWN; + + DmlGridSampleParameters(){} + + DmlGridSampleParameters( + const OperatorHelper::IKernelInformationAdapter& kernelInfo, + const OperatorHelper::IShapeInformationAdapter& shapeInfo) + { + auto& attributes = kernelInfo.GetAttributes(); + + alignCorners = attributes.GetOptionalAttribute(AttrName::AlignCorners, 0); + + std::string str_attrib = attributes.GetOptionalAttribute(AttrName::Mode, "bilinear"); + ML_CHECK_VALID_ARGUMENT(str_attrib == "bilinear" || str_attrib == "nearest" || str_attrib == "bicubic"); + if (str_attrib == "bilinear") + { + mode = DmlGridSampleMode::Bilinear; + } + else if (str_attrib == "nearest") + { + mode = DmlGridSampleMode::Nearest; + } + else if (str_attrib == "bicubic") + { + mode = DmlGridSampleMode::Bicubic; + } + + str_attrib = attributes.GetOptionalAttribute(AttrName::PaddingMode, "zeros"); + ML_CHECK_VALID_ARGUMENT(str_attrib == "zeros" || str_attrib == "border" || str_attrib == "reflection"); + if (str_attrib == "zeros") + { + paddingMode = DmlGridSamplePaddingMode::Zeros; + } + else if (str_attrib == "border") + { + paddingMode = DmlGridSamplePaddingMode::Border; + } + else if (str_attrib == "reflection") + { + paddingMode = DmlGridSamplePaddingMode::Reflection; + } + + // input 0: signal (required; tensor) + { + // Input shape is expected to be [batch_size, channels, height, width] + // 4-D tensor of shape (N, C, H_out, W_out) of sampled values. + // For integer input types, intermediate values are computed as floating point and cast to integer at the end. 
uint32_t rank = shapeInfo.GetInputTensorDimensionCount(DmlGridSampleKernelInputIndex::Input); + uint32_t rank = shapeInfo.GetInputTensorDimensionCount(DmlGridSampleKernelInputIndex::Input); + ML_CHECK_VALID_ARGUMENT(rank == 4, "Input shape must be 4D."); + + auto dims = shapeInfo.GetInputTensorShape(DmlGridSampleKernelInputIndex::Input); + assert(dims.size() == rank); + this->batchSize = dims[0]; + this->channelSize = dims[1]; + + MLOperatorEdgeDescription edgeDesc = kernelInfo.GetInputEdgeDescription(DmlGridSampleKernelInputIndex::Input); + + assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor); + this->dataType = Dml::GetDmlDataTypeFromMlDataType(edgeDesc.tensorDataType); + } + + // input 1: grid (required; tensor) + { + // Grid shape is expected to be [batch_size, height_out, width_out, 2] + uint32_t rank = shapeInfo.GetInputTensorDimensionCount(DmlGridSampleKernelInputIndex::Grid); + ML_CHECK_VALID_ARGUMENT(rank == 4, "Input shape must be 4D."); + + auto dims = shapeInfo.GetInputTensorShape(DmlGridSampleKernelInputIndex::Grid); + assert(dims.size() == rank); + this->height = dims[1]; + this->width = dims[2]; + } + } + +}; + +namespace GridSampleHelpers +{ + // Divides and rounds + inline uint32_t CeilDivide(uint32_t dividend, uint32_t divisor) + { + uint64_t temp = static_cast(dividend) + divisor - 1; + return static_cast(temp / divisor); + } + + // Gets the next number of elements to dispatch to the GPU within a loop handling a large + // total number of tensor elements and threads. + void GetNextDispatchSize( + uint32_t elementCount, + uint32_t elementsPerThread, + uint32_t numThreads, + _Out_ uint32_t& dispatch, + _Out_ uint32_t& pendingElementCount + ) + { + // Max threads per workgroup is 2^10 (1024). Max dispatch per dimension is 2^16. Taken together, we can dispatch a maximum of + // 2^26 (268,435,456) threads along a single dimension. This should suffice for a majority of the workload. 
Therefore, even + // though it is possible to dispatch up to (2^16)^3 workgroups simultaneously, we stick to the simpler 1D dispatch alternative. + assert(numThreads <= D3D12_CS_THREAD_GROUP_MAX_THREADS_PER_GROUP); + + const uint32_t maxThreadsPerDispatch = numThreads * D3D12_CS_DISPATCH_MAX_THREAD_GROUPS_PER_DIMENSION; + + const uint32_t requiredThreadCount = CeilDivide(elementCount, elementsPerThread); + + // Compute max dispatchable elements + const uint32_t availableThreadCount = std::min(requiredThreadCount, maxThreadsPerDispatch); + + // Compute required thread group count + uint32_t workGroupCount1D = CeilDivide(availableThreadCount, numThreads); + + // Compute min dispatch size + dispatch = workGroupCount1D; + + // With the dispatch size computed, compute the dispatched element count + const uint32_t dispatchedElementCount = workGroupCount1D * numThreads * elementsPerThread; + + // Update the pending element count + pendingElementCount = (dispatchedElementCount < elementCount) ? 
elementCount - dispatchedElementCount : 0; + } +} + +class DmlGridSampleOperator : public WRL::Base +{ +private: + ComPtr m_device; + ComPtr m_gridSampleRootSignature; + ComPtr m_gridSamplePipelineState; + DmlGridSampleParameters m_params = {}; + + + // Allocate temporary buffers if needed + struct ResourceDesc + { + ComPtr Resource; + std::array Sizes; + std::array Strides; + }; + + struct GridSampleShaderConstants + { + uint32_t StartIndex; + uint32_t ElementCount; + uint32_t Mode; + uint32_t PaddingMode; + uint32_t InputSizes[4]; + uint32_t InputStrides[4]; + uint32_t GridSizes[4]; + uint32_t GridStrides[4]; + uint32_t OutputSizes[4]; + uint32_t OutputStrides[4]; + uint32_t AlignCorners; + }; + +public: + + DmlGridSampleOperator(IMLOperatorKernelCreationContext* context) + { + ComPtr executionObject; + context->GetExecutionInterface(executionObject.GetAddressOf()); + + ComPtr commandList; + executionObject.As(&commandList); + + ORT_THROW_IF_FAILED(commandList->GetDevice(IID_ID3D12Device, &m_device)); + + MLOperatorKernelCreationContext creationContext(context); + OperatorHelper::KernelInformationAdapter kernelInfo{creationContext}; + OperatorHelper::ShapeInformationAdapter shapeInfo{creationContext}; + m_params = DmlGridSampleParameters(kernelInfo, shapeInfo); + + MLOperatorEdgeDescription inputEdgeDesc; + ORT_THROW_IF_FAILED(context->GetInputEdgeDescription(0, &inputEdgeDesc)); + assert(inputEdgeDesc.edgeType == MLOperatorEdgeType::Tensor); + + MLOperatorEdgeDescription gridEdgeDesc; + ORT_THROW_IF_FAILED(context->GetInputEdgeDescription(0, &gridEdgeDesc)); + assert(gridEdgeDesc.edgeType == MLOperatorEdgeType::Tensor); + + PrepareGridSample(inputEdgeDesc.tensorDataType, gridEdgeDesc.tensorDataType); + } + + void PrepareGridSample(MLOperatorTensorDataType inputDataType, MLOperatorTensorDataType gridDataType) + { + // Compute root signature. 
+ const int uavCount = 3; // 3 bound UAVs: input, grid & output + std::vector rootParameters; + rootParameters.resize(uavCount + 1); + + for (uint32_t i = 0; i < uavCount; i++) + { + rootParameters[i].InitAsUnorderedAccessView(i); + } + + // cbuffer Constants + // { + // uint StartIndex; + // uint ElementCount; + // uint Mode; + // uint PaddingMode; + // uint4 InputSizes; + // uint4 InputStrides; + // uint4 GridSizes; + // uint4 GridStrides; + // uint4 OutputSizes; + // uint4 OutputStrides; + // uint AlignCorners; + // }; + int constantCount = 29; + rootParameters[uavCount].InitAsConstants(constantCount, 0); + + CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC desc; + desc.Init_1_1(static_cast(rootParameters.size()), rootParameters.data()); + + ComPtr rootSignatureBlob; + ComPtr rootSignatureErrorBlob; + ORT_THROW_IF_FAILED(D3D12SerializeVersionedRootSignature( + &desc, + rootSignatureBlob.GetAddressOf(), + rootSignatureErrorBlob.GetAddressOf() + )); + + ORT_THROW_IF_FAILED(m_device->CreateRootSignature( + 0, + rootSignatureBlob->GetBufferPointer(), + rootSignatureBlob->GetBufferSize(), + IID_ID3D12RootSignature, + &m_gridSampleRootSignature + )); + + // Describe and create the compute pipeline state object (PSO). 
+ D3D12_COMPUTE_PIPELINE_STATE_DESC computePsoDesc = {}; + computePsoDesc.pRootSignature = m_gridSampleRootSignature.Get(); + + switch (gridDataType) + { + case MLOperatorTensorDataType::Float: + { + switch (inputDataType) + { + case MLOperatorTensorDataType::UInt16: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint16_float::g_GridSample, sizeof(GridSample_uint16_float::g_GridSample)); + break; + + case MLOperatorTensorDataType::UInt32: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint_float::g_GridSample, sizeof(GridSample_uint_float::g_GridSample)); + break; + + case MLOperatorTensorDataType::UInt64: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint64_float::g_GridSample, sizeof(GridSample_uint64_float::g_GridSample)); + break; + + case MLOperatorTensorDataType::Int16: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int16_float::g_GridSample, sizeof(GridSample_int16_float::g_GridSample)); + break; + + case MLOperatorTensorDataType::Int32: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int_float::g_GridSample, sizeof(GridSample_int_float::g_GridSample)); + break; + + case MLOperatorTensorDataType::Int64: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int64_float::g_GridSample, sizeof(GridSample_int64_float::g_GridSample)); + break; + + case MLOperatorTensorDataType::Float16: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_fp16_float::g_GridSample, sizeof(GridSample_fp16_float::g_GridSample)); + break; + + case MLOperatorTensorDataType::Float: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_float_float::g_GridSample, sizeof(GridSample_float_float::g_GridSample)); + break; + + case MLOperatorTensorDataType::Double: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_double_float::g_GridSample, sizeof(GridSample_double_float::g_GridSample)); + break; + + case MLOperatorTensorDataType::Bool: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_bool_float::g_GridSample, 
sizeof(GridSample_bool_float::g_GridSample)); + break; + + default: + ORT_THROW_HR(E_INVALIDARG); + } + break; + } + case MLOperatorTensorDataType::Float16: + { + switch (inputDataType) + { + case MLOperatorTensorDataType::UInt16: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint16_fp16::g_GridSample, sizeof(GridSample_uint16_fp16::g_GridSample)); + break; + + case MLOperatorTensorDataType::UInt32: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint_fp16::g_GridSample, sizeof(GridSample_uint_fp16::g_GridSample)); + break; + + case MLOperatorTensorDataType::UInt64: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint64_fp16::g_GridSample, sizeof(GridSample_uint64_fp16::g_GridSample)); + break; + + case MLOperatorTensorDataType::Int16: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int16_fp16::g_GridSample, sizeof(GridSample_int16_fp16::g_GridSample)); + break; + + case MLOperatorTensorDataType::Int32: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int_fp16::g_GridSample, sizeof(GridSample_int_fp16::g_GridSample)); + break; + + case MLOperatorTensorDataType::Int64: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int64_fp16::g_GridSample, sizeof(GridSample_int64_fp16::g_GridSample)); + break; + + case MLOperatorTensorDataType::Float16: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_fp16_fp16::g_GridSample, sizeof(GridSample_fp16_fp16::g_GridSample)); + break; + + case MLOperatorTensorDataType::Float: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_float_fp16::g_GridSample, sizeof(GridSample_float_fp16::g_GridSample)); + break; + + case MLOperatorTensorDataType::Double: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_double_fp16::g_GridSample, sizeof(GridSample_double_fp16::g_GridSample)); + break; + + case MLOperatorTensorDataType::Bool: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_bool_fp16::g_GridSample, sizeof(GridSample_bool_fp16::g_GridSample)); + break; + + 
default: + ORT_THROW_HR(E_INVALIDARG); + } + break; + } + case MLOperatorTensorDataType::Double: + { + switch (inputDataType) + { + case MLOperatorTensorDataType::UInt16: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint16_double::g_GridSample, sizeof(GridSample_uint16_double::g_GridSample)); + break; + + case MLOperatorTensorDataType::UInt32: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint_double::g_GridSample, sizeof(GridSample_uint_double::g_GridSample)); + break; + + case MLOperatorTensorDataType::UInt64: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_uint64_double::g_GridSample, sizeof(GridSample_uint64_double::g_GridSample)); + break; + + case MLOperatorTensorDataType::Int16: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int16_double::g_GridSample, sizeof(GridSample_int16_double::g_GridSample)); + break; + + case MLOperatorTensorDataType::Int32: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int_double::g_GridSample, sizeof(GridSample_int_double::g_GridSample)); + break; + + case MLOperatorTensorDataType::Int64: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_int64_double::g_GridSample, sizeof(GridSample_int64_double::g_GridSample)); + break; + + case MLOperatorTensorDataType::Float16: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_fp16_double::g_GridSample, sizeof(GridSample_fp16_double::g_GridSample)); + break; + + case MLOperatorTensorDataType::Float: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_float_double::g_GridSample, sizeof(GridSample_float_double::g_GridSample)); + break; + + case MLOperatorTensorDataType::Double: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_double_double::g_GridSample, sizeof(GridSample_double_double::g_GridSample)); + break; + + case MLOperatorTensorDataType::Bool: + computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(GridSample_bool_double::g_GridSample, sizeof(GridSample_bool_double::g_GridSample)); + break; + + default: + 
ORT_THROW_HR(E_INVALIDARG); + } + break; + } + default: + ORT_THROW_HR(E_INVALIDARG); + } + + ORT_THROW_IF_FAILED(m_device->CreateComputePipelineState(&computePsoDesc, IID_ID3D12PipelineState, &m_gridSamplePipelineState)); + } + + // Computes the outputs of the kernel. This may be called multiple times + // simultaneously within the same instance of the class. Implementations + // of this method must be thread-safe. + STDMETHOD(Compute)(IMLOperatorKernelContext* context) + { + try + { + // Get the input tensor + ComPtr inputTensor; + ORT_THROW_IF_FAILED(context->GetInputTensor(0, inputTensor.GetAddressOf())); + + // Get the grid tensor + ComPtr gridTensor; + ORT_THROW_IF_FAILED(context->GetInputTensor(1, gridTensor.GetAddressOf())); + + // Get the output tensor + ComPtr outputTensor; + context->GetOutputTensor(0, outputTensor.GetAddressOf()); + + if (outputTensor->IsCpuData() || inputTensor->IsCpuData() || gridTensor->IsCpuData()) + { + return E_UNEXPECTED; + } + + ComPtr executionObject; + ComPtr commandList; + context->GetExecutionInterface(executionObject.GetAddressOf()); + executionObject.As(&commandList); + + // Get the input and output shape sizes + auto inputDims = GetTensorDimensions(inputTensor.Get()); + auto gridDims = GetTensorDimensions(gridTensor.Get()); + auto outputDims = GetTensorDimensions(outputTensor.Get()); + + ComPtr inputUnknown; + ComPtr inputResource; + inputTensor->GetDataInterface(inputUnknown.GetAddressOf()); + ORT_THROW_IF_FAILED(inputUnknown.As(&inputResource)); + + ComPtr gridUnknown; + ComPtr gridResource; + gridTensor->GetDataInterface(gridUnknown.GetAddressOf()); + ORT_THROW_IF_FAILED(gridUnknown.As(&gridResource)); + + ComPtr outputUnknown; + ComPtr outputResource; + outputTensor->GetDataInterface(outputUnknown.GetAddressOf()); + ORT_THROW_IF_FAILED(outputUnknown.As(&outputResource)); + + return Compute( + commandList.Get(), + context, + inputResource.Get(), + inputDims, + gridResource.Get(), + gridDims, + outputResource.Get(), + 
outputDims + ); + } + catch (...) + { + return E_FAIL; + } + + return S_OK; + } + + HRESULT Compute( + ID3D12GraphicsCommandList* commandList, + IMLOperatorKernelContext* context, + ID3D12Resource* inputResource, + gsl::span inputDims, + ID3D12Resource* gridResource, + gsl::span gridDims, + ID3D12Resource* outputResource, + gsl::span outputDims) + { + try + { + GridSample( + inputResource, + inputDims, + gridResource, + gridDims, + outputResource, + outputDims, + commandList); + } + catch (...) + { + return E_FAIL; + } + + return S_OK; + } + + void GridSample( + ID3D12Resource* inputResource, + gsl::span inputDims, + ID3D12Resource* gridResource, + gsl::span gridDims, + ID3D12Resource* outputResource, + gsl::span outputDims, + ID3D12GraphicsCommandList* commandList) + { + std::array inputStrides; + std::array gridStrides; + std::array outputStrides; + Dml::GetDescendingPackedStrides(inputDims, inputStrides); + Dml::GetDescendingPackedStrides(gridDims, gridStrides); + Dml::GetDescendingPackedStrides(outputDims, outputStrides); + + // Transition resources from common to UAV state + D3D12_RESOURCE_BARRIER barriers[3]; + + barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( + inputResource, + D3D12_RESOURCE_STATE_COMMON, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS + ); + + barriers[1] = CD3DX12_RESOURCE_BARRIER::Transition( + gridResource, + D3D12_RESOURCE_STATE_COMMON, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS + ); + + barriers[2] = CD3DX12_RESOURCE_BARRIER::Transition( + outputResource, + D3D12_RESOURCE_STATE_COMMON, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS + ); + + inputResource->SetName(L"InputResource"); + outputResource->SetName(L"OutputResource"); + gridResource->SetName(L"GridResource"); + + commandList->ResourceBarrier(3, barriers); + + // Set the root signature and pipeline state + commandList->SetComputeRootSignature(m_gridSampleRootSignature.Get()); + commandList->SetPipelineState(m_gridSamplePipelineState.Get()); + + // Each iteration of the below loop represents 
1 level in the Stockham DFT + // Dispatch in a loop + GridSampleShaderConstants constants = {}; + constants.AlignCorners = static_cast(m_params.alignCorners); + constants.Mode = static_cast(m_params.mode); + constants.PaddingMode = static_cast(m_params.paddingMode); + std::copy(inputDims.begin(), inputDims.end(), constants.InputSizes); + std::copy(inputStrides.begin(), inputStrides.end(), constants.InputStrides); + std::copy(gridDims.begin(), gridDims.end(), constants.GridSizes); + std::copy(gridStrides.begin(), gridStrides.end(), constants.GridStrides); + std::copy(outputDims.begin(), outputDims.end(), constants.OutputSizes); + std::copy(outputStrides.begin(), outputStrides.end(), constants.OutputStrides); + + constants.ElementCount = ComputeElementCountFromDimensions(constants.OutputSizes); + std::array uav_resources = { inputResource, gridResource, outputResource }; + Dispatch(uav_resources, constants, commandList); + + // Transition resources to common state + barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( + inputResource, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON + ); + + barriers[1] = CD3DX12_RESOURCE_BARRIER::Transition( + gridResource, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON + ); + + barriers[2] = CD3DX12_RESOURCE_BARRIER::Transition( + outputResource, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON + ); + + commandList->ResourceBarrier(3, barriers); + } + + std::vector GetTensorDimensions(IMLOperatorTensor* tensor) + { + auto inputDimsSize = tensor->GetDimensionCount(); + auto dims = std::vector(inputDimsSize); + ORT_THROW_IF_FAILED(tensor->GetShape(static_cast(dims.size()), dims.data())); + return dims; + } + + template + void Dispatch( + std::array& resources, + TConstants& constants, + ID3D12GraphicsCommandList* commandList) + { + D3D12_RESOURCE_BARRIER uav_barriers[TSize]; + + std::transform( + resources.begin(), resources.end(), + uav_barriers, + [](auto& resource) 
{ return CD3DX12_RESOURCE_BARRIER::UAV(resource); } ); + commandList->ResourceBarrier(TSize, uav_barriers); + + for (uint32_t i = 0; i < TSize; i++) + { + // Set resource views + if (resources[i]) { + commandList->SetComputeRootUnorderedAccessView( + i, // root parameter index + resources[i]->GetGPUVirtualAddress() + ); + } + else + { + commandList->SetComputeRootUnorderedAccessView( + i, // root parameter index + {} + ); + + } + } + + auto pendingElementCount = constants.ElementCount; + + // Dispatch up to the maximum number of threads per iteration until + // all elements are completed + while (pendingElementCount > 0) + { + constants.StartIndex = constants.ElementCount - pendingElementCount; + + uint32_t dispatchSizeX; + + GridSampleHelpers::GetNextDispatchSize( + pendingElementCount, + 1, + 64, + dispatchSizeX, + pendingElementCount + ); + + // Set root constants + commandList->SetComputeRoot32BitConstants( + TSize, // root parameter index + 29, // Constant count + &constants, + 0 // offset + ); + + commandList->Dispatch(dispatchSizeX, 1, 1); + } + + commandList->ResourceBarrier(2, uav_barriers); + } +}; + +struct GridSampleShapeInferrer : public WRL::Base +{ + STDMETHOD(InferOutputShapes)(IMLOperatorShapeInferenceContext* context) noexcept + { + try + { + ComPtr contextPrivate; + ORT_THROW_IF_FAILED(context->QueryInterface(IID_PPV_ARGS(&contextPrivate))); + + MLShapeInferenceContext inferenceContext(context); + OperatorHelper::KernelInformationAdapter kernelInfo{inferenceContext}; + OperatorHelper::ShapeInformationAdapter shapeInfo{inferenceContext}; + DmlGridSampleParameters params(kernelInfo, shapeInfo); + + std::array outputDims = { params.batchSize, params.channelSize, params.height, params.width }; + + ORT_THROW_IF_FAILED(context->SetOutputTensorShape(0, onnxruntime::narrow(outputDims.size()), outputDims.data())); + } + catch (...) 
+ { + return E_FAIL; + } + + return S_OK; + } +}; + +class DmlGridSampleOperatorFactory : public WRL::Base +{ +public: + STDMETHOD(CreateKernel)( + IMLOperatorKernelCreationContext* context, + IMLOperatorKernel** kernel) + { + try + { + auto dftOperator = wil::MakeOrThrow(context); + dftOperator.CopyTo(kernel); + return S_OK; + } + catch (...) + { + return E_FAIL; + } + } + + static void RegisterGridSampleKernel(IMLOperatorRegistry* registry) + { + MLOperatorKernelDescription kernelDescription = {}; + kernelDescription.domain = ""; + kernelDescription.name = "GridSample"; + kernelDescription.minimumOperatorSetVersion = 16; + kernelDescription.executionType = MLOperatorExecutionType::D3D12; + + // T1: tensor(float16), tensor(float), tensor(double), tensor(bfloat16) + MLOperatorEdgeTypeConstrant t1Constraint; + t1Constraint.typeLabel = "T1"; + std::vector t1AllowedEdges + { + MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Float }, + MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Float16 }, + MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Int8 }, + MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Int16 }, + MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Int32 }, + MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Int64 }, + MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::UInt8 }, + MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::UInt16 }, + MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::UInt32 }, + MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::UInt64 }, + }; + t1Constraint.allowedTypes = t1AllowedEdges.data(); + 
t1Constraint.allowedTypeCount = static_cast(t1AllowedEdges.size()); + + // T2 : tensor(int32), tensor(int64) + MLOperatorEdgeTypeConstrant t2Constraint; + t2Constraint.typeLabel = "T2"; + std::vector t2AllowedEdges + { + MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Float16 }, + MLOperatorEdgeDescription { MLOperatorEdgeType::Tensor, (uint64_t)MLOperatorTensorDataType::Float }, + }; + t2Constraint.allowedTypes = t2AllowedEdges.data(); + t2Constraint.allowedTypeCount = static_cast(t2AllowedEdges.size()); + + std::vector typeConstraints{ t1Constraint, t2Constraint }; + kernelDescription.typeConstraints = typeConstraints.data(); + kernelDescription.typeConstraintCount = static_cast(typeConstraints.size()); + + MLOperatorAttributeNameValue alignedCornersAttributeValue; + alignedCornersAttributeValue.name = AttrName::AlignCorners; + alignedCornersAttributeValue.type = MLOperatorAttributeType::Int; + alignedCornersAttributeValue.valueCount = 1; + static const int64_t alignedCorners[] = { 0 }; + alignedCornersAttributeValue.ints = alignedCorners; + + MLOperatorAttributeNameValue modeAttributeValue; + modeAttributeValue.name = AttrName::Mode; + modeAttributeValue.type = MLOperatorAttributeType::String; + modeAttributeValue.valueCount = 1; + static const char* modes[] = { "bilinear" }; + modeAttributeValue.strings = modes; + + MLOperatorAttributeNameValue paddingModeAttributeValue; + paddingModeAttributeValue.name = AttrName::Mode; + paddingModeAttributeValue.type = MLOperatorAttributeType::String; + paddingModeAttributeValue.valueCount = 1; + static const char* paddingModes[] = { "zeros" }; + paddingModeAttributeValue.strings = paddingModes; + + std::vector attributeDefaultValues{ + alignedCornersAttributeValue, + modeAttributeValue, + paddingModeAttributeValue + }; + + kernelDescription.defaultAttributes = attributeDefaultValues.data(); + kernelDescription.defaultAttributeCount = static_cast(attributeDefaultValues.size()); + 
kernelDescription.options = MLOperatorKernelOptions::None; + kernelDescription.executionOptions = 0; + + auto shareInferrer = wil::MakeOrThrow(); + auto factory = wil::MakeOrThrow(); + + ComPtr registryPrivate; + ORT_THROW_IF_FAILED(registry->QueryInterface(IID_PPV_ARGS(®istryPrivate))); + + ORT_THROW_IF_FAILED(registryPrivate->RegisterOperatorKernel( + &kernelDescription, + factory.Get(), + shareInferrer.Get(), + nullptr, + false, // isInternalOperator + false, // alias + false, // supportsGraph + nullptr, + nullptr, + 0)); + + } +}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index 132b099aab..bbebb4a333 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -10,32 +10,20 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size M A B C // M, A, B, and C are Inputs | \ | / - Cast Gemm + | Gemm | / | \ | / | \ | / | \ | Slice Slice Slice - Identity | | | + | | | | | | | | | Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while | | | | // keeping the GEMM strides as NCHW to better target metacommands | | | | - | ----- | - ----------- | | - \ | | - Gemm | - | | - | | - Softmax | - | / - | / - \ / - \ / - Gemm - | - ActivationLinear - | - Output // Final output + ----------------- MHA ----- + | + | + Output // Final output This kernel creates a DML_GRAPH, as mentioned above. 
For reference, refer to this Doc: @@ -49,51 +37,165 @@ class DmlOperatorAttention : public DmlOperator DmlOperatorAttention(const MLOperatorKernelCreationContext& kernelCreationContext) : DmlOperator(kernelCreationContext) { - ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 3); + enum DmlInputIndex : uint32_t + { + mhaQueryIndex, + mhaKeyIndex, + mhaValueIndex, + mhaStackedQueryKeyIndex, + mhaStackedKeyValueIndex, + mhaStackedQueryKeyValueIndex, + mhaBiasIndex, + mhaMaskIndex, + mhaRelativePositionBiasIndex, + mhaPastKeyIndex, + mhaPastValueIndex, + mhaInputCount, + }; + + enum InputIndex : uint32_t + { + inputIndex, + weightsIndex, + biasIndex, + maskIndex, + pastIndex, + relativePositionBiasIndex, + pastSequenceLengthIndex, + inputCount, + }; + + enum OutputIndex : uint32_t + { + outputIndex, + outputCount, + }; + + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 2); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1); + + const uint32_t dmlInputIndex = inputIndex; + const uint32_t dmlWeightsIndex = weightsIndex; + const uint32_t dmlBiasIndex = biasIndex; + const uint32_t dmlMaskIndex = maskIndex; + const uint32_t dmlRelativePositionBiasIndex = relativePositionBiasIndex; + + const bool hasBias = kernelCreationContext.IsInputValid(biasIndex); + const bool hasMask = kernelCreationContext.IsInputValid(maskIndex); + const bool hasUnpaddedBounds = hasMask && kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1; + const bool hasRelativePositionBias = kernelCreationContext.IsInputValid(relativePositionBiasIndex); + DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 1); - std::vector inputTensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0); - std::vector weightTensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(1); - std::vector biasTensorShape = 
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(2); - std::vector maskIndexTensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(3); + const uint32_t numHeads = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::NumHeads)); + ML_CHECK_VALID_ARGUMENT(numHeads > 0); // to avoid process crash because of division by zero. + + auto inputTensorShape = m_inputTensorDescs[dmlInputIndex].GetSizes(); ML_CHECK_VALID_ARGUMENT(inputTensorShape.size() == 3); + + auto weightTensorShape = m_inputTensorDescs[dmlWeightsIndex].GetSizes(); ML_CHECK_VALID_ARGUMENT(weightTensorShape.size() == 2); - ML_CHECK_VALID_ARGUMENT(biasTensorShape.size() == 1); - ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] == biasTensorShape[0]); - ML_CHECK_VALID_ARGUMENT(biasTensorShape[0] % 3 == 0); - ML_CHECK_VALID_ARGUMENT(inputTensorShape[2] == weightTensorShape[0]); - // TODO: fix Attention kernel when maskIndexTensorShape is 1 - // https://microsoft.visualstudio.com/OS/_workitems/edit/41893987 - ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape.size() > 1 && maskIndexTensorShape.size() <= 4); + ML_CHECK_VALID_ARGUMENT(weightTensorShape[0] == inputTensorShape[2]); + + const auto qkvHiddenSizes = kernelCreationContext.GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes); + if (hasBias) + { + auto biasTensorShape = m_inputTensorDescs[dmlBiasIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(biasTensorShape.size() == 1); + ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] == biasTensorShape[0]); + + if (qkvHiddenSizes.empty()) + { + ML_CHECK_VALID_ARGUMENT(biasTensorShape[0] % 3 == 0); + } + } + + if (!qkvHiddenSizes.empty()) + { + ML_CHECK_VALID_ARGUMENT(qkvHiddenSizes.size() == 3); + ML_CHECK_VALID_ARGUMENT(qkvHiddenSizes[0] == qkvHiddenSizes[1]); + } + else + { + ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] % 3 == 0); + } + + const uint32_t hiddenSize = qkvHiddenSizes.empty() ? 
weightTensorShape[1] / 3 : qkvHiddenSizes[0]; + const uint32_t vHiddenSize = qkvHiddenSizes.empty() ? weightTensorShape[1] / 3 : qkvHiddenSizes[2]; + const uint32_t headSize = hiddenSize / numHeads; + const uint32_t vHeadSize = vHiddenSize / numHeads; const uint32_t batchSize = inputTensorShape[0]; const uint32_t sequenceLength = inputTensorShape[1]; - const uint32_t hiddenSize = biasTensorShape[0] / 3; - const uint32_t numHeads = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::NumHeads)); - ML_CHECK_VALID_ARGUMENT(numHeads > 0); // to avoid process crash because of division by zero. - ML_CHECK_VALID_ARGUMENT(hiddenSize % numHeads == 0); - const uint32_t headSize = hiddenSize / numHeads; - uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], 3 * hiddenSize}; - uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, 3 * hiddenSize}; - MLOperatorTensorDataType dataType = kernelCreationContext.GetInputEdgeDescription(0).tensorDataType; + uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], hiddenSize + hiddenSize + vHiddenSize}; + MLOperatorTensorDataType dataType = kernelCreationContext.GetInputEdgeDescription(inputIndex).tensorDataType; - // overwrite weightTensorDesc - m_inputTensorDescs[1] = TensorDesc::ConstructBroadcastedTensorDesc(dataType, desiredWeightTensorShape, weightTensorShape); + m_inputTensorDescs[dmlWeightsIndex] = TensorDesc::ConstructBroadcastedTensorDesc(dataType, desiredWeightTensorShape, weightTensorShape); - // overwrite biasTensorDesc - m_inputTensorDescs[2] = TensorDesc::ConstructBroadcastedTensorDesc(dataType, desiredBiasTensorShape, biasTensorShape); + uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, hiddenSize + hiddenSize + vHiddenSize}; + if (hasBias) + { + auto biasTensorShape = m_inputTensorDescs[dmlBiasIndex].GetSizes(); + m_inputTensorDescs[dmlBiasIndex] = TensorDesc::ConstructBroadcastedTensorDesc(dataType, desiredBiasTensorShape, biasTensorShape); + } - 
// overwrite maskIndexTensorDesc - uint32_t maskIndexDimensionCount = gsl::narrow_cast(maskIndexTensorShape.size()); - maskIndexTensorShape.insert(maskIndexTensorShape.begin() + 1, 4 - maskIndexDimensionCount, 1); - uint32_t desiredMaskIndexShape[4] {batchSize, numHeads, sequenceLength, sequenceLength}; - MLOperatorTensorDataType maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(3).tensorDataType; - m_inputTensorDescs[3] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, maskIndexTensorShape); + MLOperatorTensorDataType maskTensorDataType = MLOperatorTensorDataType::Undefined; + bool hasMaxSequenceMask = false; + DML_MULTIHEAD_ATTENTION_MASK_TYPE maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE; + if (hasMask) + { + if (hasUnpaddedBounds) + { + auto unpaddedKeyBoundsShape = m_inputTensorDescs[dmlMaskIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(unpaddedKeyBoundsShape.size() == 1); + + const uint32_t batchGroupCount = unpaddedKeyBoundsShape[0] / batchSize; + ML_CHECK_VALID_ARGUMENT(batchGroupCount == 1 || batchGroupCount == 2); + + uint32_t desiredShape[2] = {batchGroupCount, batchSize}; + m_inputTensorDescs[dmlMaskIndex] = TensorDesc( + m_inputTensorDescs[dmlMaskIndex].GetDmlDataType(), + desiredShape); + + maskType = batchGroupCount == 1 + ? 
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH + : DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START; + } + else + { + auto maskIndexTensorShape = m_inputTensorDescs[dmlMaskIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape.size() > 1 && maskIndexTensorShape.size() <= 4); + + maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN; + std::vector reshapedMaskIndexTensorShape(maskIndexTensorShape.begin(), maskIndexTensorShape.end()); + if (maskIndexTensorShape.size() == 4 && maskIndexTensorShape[2] != sequenceLength) + { + hasMaxSequenceMask = true; + ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape[2] == maskIndexTensorShape[3]); + const uint32_t maxSequenceLength = maskIndexTensorShape[2]; + uint32_t desiredMaskIndexShape[4] {batchSize, numHeads, maxSequenceLength, maxSequenceLength}; + maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType; + m_inputTensorDescs[dmlMaskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape); + } + else + { + uint32_t maskIndexDimensionCount = gsl::narrow_cast(maskIndexTensorShape.size()); + reshapedMaskIndexTensorShape.insert(reshapedMaskIndexTensorShape.begin() + 1, 4 - maskIndexDimensionCount, 1); + uint32_t desiredMaskIndexShape[4] {batchSize, numHeads, sequenceLength, sequenceLength}; + maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType; + m_inputTensorDescs[dmlMaskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape); + } + } + } - // overwrite output tensor desc - uint32_t outputTensorShape[4] = {batchSize, sequenceLength, numHeads, headSize}; - uint32_t outputTensorStrides[4] = {sequenceLength * numHeads * headSize, headSize, headSize * sequenceLength, 1}; - m_outputTensorDescs[0] = TensorDesc(GetDmlDataTypeFromMlDataType(dataType), outputTensorShape, outputTensorStrides, 0); + if 
(hasRelativePositionBias) + { + auto relativePositionBiasTensorShape = m_inputTensorDescs[dmlRelativePositionBiasIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(relativePositionBiasTensorShape.size() == 4); + ML_CHECK_VALID_ARGUMENT(relativePositionBiasTensorShape[0] == inputTensorShape[0]); + ML_CHECK_VALID_ARGUMENT(relativePositionBiasTensorShape[1] == numHeads); + ML_CHECK_VALID_ARGUMENT(relativePositionBiasTensorShape[2] == inputTensorShape[1]); + } TensorDesc firstGemmOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape); DML_TENSOR_DESC namedFirstGemmOutputTensorDesc = firstGemmOutputTensorDesc.GetDmlDesc(); @@ -101,345 +203,335 @@ class DmlOperatorAttention : public DmlOperator std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - DML_GEMM_OPERATOR_DESC xWeightOperatorDesc = {}; - xWeightOperatorDesc.ATensor = &inputDescs[0]; - xWeightOperatorDesc.BTensor = &inputDescs[1]; - xWeightOperatorDesc.CTensor = &inputDescs[2]; - xWeightOperatorDesc.OutputTensor = &namedFirstGemmOutputTensorDesc; - xWeightOperatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE; - xWeightOperatorDesc.TransB = DML_MATRIX_TRANSFORM_NONE; - xWeightOperatorDesc.Alpha = 1.0f; - xWeightOperatorDesc.Beta = 1.0f; - xWeightOperatorDesc.FusedActivation = nullptr; - const DML_OPERATOR_DESC xWeightDesc {DML_OPERATOR_GEMM, &xWeightOperatorDesc}; - - - std::array querySlicedTensorShape {batchSize, sequenceLength, hiddenSize}; - TensorDesc querySlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, querySlicedTensorShape); - DML_TENSOR_DESC namedQuerySlicedInputTensorDesc = querySlicedInputTensorDesc.GetDmlDesc(); - - std::array querySliceOffset = {0, 0, 0}; - std::array keySliceOffset = {0, 0, hiddenSize}; + DML_GEMM_OPERATOR_DESC gemmOperatorDesc = {}; + gemmOperatorDesc.ATensor = &inputDescs[0]; + gemmOperatorDesc.BTensor = &inputDescs[1]; + + if (hasBias) + { + gemmOperatorDesc.CTensor = &inputDescs[2]; + } + + 
gemmOperatorDesc.OutputTensor = &namedFirstGemmOutputTensorDesc; + gemmOperatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE; + gemmOperatorDesc.TransB = DML_MATRIX_TRANSFORM_NONE; + gemmOperatorDesc.Alpha = 1.0f; + gemmOperatorDesc.Beta = 1.0f; + gemmOperatorDesc.FusedActivation = nullptr; + const DML_OPERATOR_DESC gemmDesc {DML_OPERATOR_GEMM, &gemmOperatorDesc}; + + std::array queryKeySlicedTensorShape {batchSize, sequenceLength, hiddenSize + hiddenSize}; + TensorDesc queryKeySlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeySlicedTensorShape); + DML_TENSOR_DESC namedQueryKeySlicedInputTensorDesc = queryKeySlicedInputTensorDesc.GetDmlDesc(); + + std::array valueSlicedTensorShape {batchSize, sequenceLength, vHiddenSize}; + TensorDesc valueSlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, valueSlicedTensorShape); + DML_TENSOR_DESC namedValueSlicedInputTensorDesc = valueSlicedInputTensorDesc.GetDmlDesc(); + + // Transpose slice QK from [batchSize, sequenceLength, 2, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 2, headSize] + std::array queryKeyTransposedTensorShape {batchSize, sequenceLength, numHeads, 2, headSize}; + std::array queryKeyTransposedStrides { + sequenceLength * numHeads * 2 * headSize, + numHeads * 2 * headSize, + headSize, + numHeads * headSize, + 1, + }; + + TensorDesc queryKeyTransposedInputTensorDesc = TensorDesc( + m_inputTensorDescs[dmlInputIndex].GetDmlDataType(), + queryKeyTransposedTensorShape, + queryKeyTransposedStrides); + DML_TENSOR_DESC namedQueryKeyTransposedInputTensorDesc = queryKeyTransposedInputTensorDesc.GetDmlDesc(); + + TensorDesc queryKeyTransposedOutputTensorDesc = TensorDesc( + m_inputTensorDescs[dmlInputIndex].GetDmlDataType(), + queryKeyTransposedTensorShape); + DML_TENSOR_DESC namedQueryKeyTransposedOutputTensorDesc = queryKeyTransposedOutputTensorDesc.GetDmlDesc(); + + // Transpose QKV from [batchSize, sequenceLength, 3, numHeads, headSize] to [batchSize, 
sequenceLength, numHeads, 3, headSize] + std::array queryKeyValueTransposedTensorShape {batchSize, sequenceLength, numHeads, 3, headSize}; + std::array queryKeyValueTransposedStrides { + sequenceLength * numHeads * 3 * headSize, + numHeads * 3 * headSize, + headSize, + numHeads * headSize, + 1, + }; + + TensorDesc queryKeyValueTransposedInputTensorDesc = TensorDesc( + m_inputTensorDescs[dmlInputIndex].GetDmlDataType(), + queryKeyValueTransposedTensorShape, + queryKeyValueTransposedStrides); + DML_TENSOR_DESC namedQueryKeyValueTransposedInputTensorDesc = queryKeyValueTransposedInputTensorDesc.GetDmlDesc(); + + TensorDesc queryKeyValueTransposedOutputTensorDesc = TensorDesc( + m_inputTensorDescs[dmlInputIndex].GetDmlDataType(), + queryKeyValueTransposedTensorShape); + DML_TENSOR_DESC namedQueryKeyValueTransposedOutputTensorDesc = queryKeyValueTransposedOutputTensorDesc.GetDmlDesc(); + + std::array queryKeySliceOffset = {0, 0, 0}; + std::array queryKeySliceSize = {batchSize, sequenceLength, hiddenSize + hiddenSize}; + std::array queryKeySliceStrides = {1, 1, 1}; + std::array valueSliceOffset = {0, 0, 2 * hiddenSize}; - std::array sliceSize = {batchSize, sequenceLength, hiddenSize}; - std::array strides = {1, 1, 1}; - DML_SLICE1_OPERATOR_DESC querySlicedOperatorDesc = {}; - querySlicedOperatorDesc.InputTensor = &namedFirstGemmOutputTensorDesc; - querySlicedOperatorDesc.OutputTensor = &namedQuerySlicedInputTensorDesc; - querySlicedOperatorDesc.DimensionCount = gsl::narrow_cast(querySlicedTensorShape.size()); - querySlicedOperatorDesc.InputWindowOffsets = querySliceOffset.data(); - querySlicedOperatorDesc.InputWindowSizes = sliceSize.data(); - querySlicedOperatorDesc.InputWindowStrides = strides.data(); - const DML_OPERATOR_DESC querySlicedDesc = { DML_OPERATOR_SLICE1, &querySlicedOperatorDesc }; - - DML_SLICE1_OPERATOR_DESC keySlicedOperatorDesc = {}; - keySlicedOperatorDesc.InputTensor = &namedFirstGemmOutputTensorDesc; - keySlicedOperatorDesc.OutputTensor = 
&namedQuerySlicedInputTensorDesc; - keySlicedOperatorDesc.DimensionCount = gsl::narrow_cast(querySlicedTensorShape.size()); - keySlicedOperatorDesc.InputWindowOffsets = keySliceOffset.data(); - keySlicedOperatorDesc.InputWindowSizes = sliceSize.data(); - keySlicedOperatorDesc.InputWindowStrides = strides.data(); - const DML_OPERATOR_DESC keySlicedDesc = { DML_OPERATOR_SLICE1, &keySlicedOperatorDesc }; + std::array valueSliceSize = {batchSize, sequenceLength, vHiddenSize}; + std::array valueSliceStrides = {1, 1, 1}; + const bool hasSlicedValue = hiddenSize != vHiddenSize; + // We need to slice the value tensor when its hidden size is different from the query and key + DML_SLICE1_OPERATOR_DESC queryKeySlicedOperatorDesc = {}; DML_SLICE1_OPERATOR_DESC valueSlicedOperatorDesc = {}; - valueSlicedOperatorDesc.InputTensor = &namedFirstGemmOutputTensorDesc; - valueSlicedOperatorDesc.OutputTensor = &namedQuerySlicedInputTensorDesc; - valueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast(querySlicedTensorShape.size()); - valueSlicedOperatorDesc.InputWindowOffsets = valueSliceOffset.data(); - valueSlicedOperatorDesc.InputWindowSizes = sliceSize.data(); - valueSlicedOperatorDesc.InputWindowStrides = strides.data(); + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposeOperatorDesc = {}; + if (hasSlicedValue) + { + queryKeySlicedOperatorDesc.InputTensor = &namedFirstGemmOutputTensorDesc; + queryKeySlicedOperatorDesc.OutputTensor = &namedQueryKeySlicedInputTensorDesc; + queryKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast(queryKeySlicedTensorShape.size()); + queryKeySlicedOperatorDesc.InputWindowOffsets = queryKeySliceOffset.data(); + queryKeySlicedOperatorDesc.InputWindowSizes = queryKeySliceSize.data(); + queryKeySlicedOperatorDesc.InputWindowStrides = queryKeySliceStrides.data(); + + valueSlicedOperatorDesc.InputTensor = &namedFirstGemmOutputTensorDesc; + valueSlicedOperatorDesc.OutputTensor = &namedValueSlicedInputTensorDesc; + valueSlicedOperatorDesc.DimensionCount 
= gsl::narrow_cast(valueSlicedTensorShape.size()); + valueSlicedOperatorDesc.InputWindowOffsets = valueSliceOffset.data(); + valueSlicedOperatorDesc.InputWindowSizes = valueSliceSize.data(); + valueSlicedOperatorDesc.InputWindowStrides = valueSliceStrides.data(); + + transposeOperatorDesc.InputTensor = &namedQueryKeyTransposedInputTensorDesc; + transposeOperatorDesc.OutputTensor = &namedQueryKeyTransposedOutputTensorDesc; + } + else + { + // When Q/K/V all have the same hidden size, we just have to transpose it before sending it to MHA + transposeOperatorDesc.InputTensor = &namedQueryKeyValueTransposedInputTensorDesc; + transposeOperatorDesc.OutputTensor = &namedQueryKeyValueTransposedOutputTensorDesc; + } + const DML_OPERATOR_DESC queryKeySlicedDesc = { DML_OPERATOR_SLICE1, &queryKeySlicedOperatorDesc}; const DML_OPERATOR_DESC valueSlicedDesc = { DML_OPERATOR_SLICE1, &valueSlicedOperatorDesc}; + const DML_OPERATOR_DESC transposedDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposeOperatorDesc}; + + std::array maskSliceOutputShape {batchSize, numHeads, sequenceLength, sequenceLength}; + std::array maskSliceStrides = {1, 1, 1, 1}; + std::array maskSliceOffsets = {0, 0, 0, 0}; + TensorDesc maskSliceOutputTensorDesc; + DML_TENSOR_DESC namedMaskSliceOutputTensorDesc; - TensorDesc castedMaskIndexTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredMaskIndexShape); - DML_TENSOR_DESC namedCastedMaskIndexTensorDesc = castedMaskIndexTensorDesc.GetDmlDesc(); - - DML_CAST_OPERATOR_DESC castMaskIndexOperatorDesc = {}; - castMaskIndexOperatorDesc.InputTensor = &inputDescs[3]; - castMaskIndexOperatorDesc.OutputTensor = &namedCastedMaskIndexTensorDesc; - const DML_OPERATOR_DESC castMaskIndexDesc = {DML_OPERATOR_CAST, &castMaskIndexOperatorDesc}; - - // The attention fusion in ORT expects this to be number to -10000. 
- // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/optimizer/attention_fusion_helper.h#L604 - // The decomposed Attention performs: (M - 1.0) * -10000.0, where M is the 4th input of the Attention node. - // Above equation can be written as (M * -1000) + 10000.0 - DML_SCALE_BIAS scaleBias = {}; - scaleBias.Scale = -10000.0f; - scaleBias.Bias = 10000.0f; - DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC maskOperatorDesc = {}; - maskOperatorDesc.InputTensor = &namedCastedMaskIndexTensorDesc; - maskOperatorDesc.OutputTensor = &namedCastedMaskIndexTensorDesc; - maskOperatorDesc.ScaleBias = &scaleBias; - const DML_OPERATOR_DESC maskDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &maskOperatorDesc}; - - // original reshaped shape: [batchSize, seqenceLength, numHeads, headSize] - // transposed shape to [0, 2, 1, 3] -> [batchSize, numHeads, sequenceLength, headSize] - uint32_t reshapedTransposedQueryTensorShape[4] = {batchSize, numHeads, sequenceLength, headSize}; - uint32_t reshapedTransposedQueryTensorStrides[4] = {sequenceLength * numHeads * headSize, headSize, numHeads * headSize, 1}; - TensorDesc reshapedTransposedQueryTensorDesc = TensorDesc( - GetDmlDataTypeFromMlDataType(dataType), - reshapedTransposedQueryTensorShape, - reshapedTransposedQueryTensorStrides); - DML_TENSOR_DESC namedReshapedTransposedQueryTensorDesc = reshapedTransposedQueryTensorDesc.GetDmlDesc(); - - TensorDesc reshapedTransposedQueryOutputTensorDesc = TensorDesc( - GetDmlDataTypeFromMlDataType(dataType), - reshapedTransposedQueryTensorShape); - DML_TENSOR_DESC namedReshapedTransposedQueryOutputTensorDesc = reshapedTransposedQueryOutputTensorDesc.GetDmlDesc(); - - DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposedQueryOperatorDesc{}; - transposedQueryOperatorDesc.InputTensor = &namedReshapedTransposedQueryTensorDesc; - transposedQueryOperatorDesc.OutputTensor = &namedReshapedTransposedQueryOutputTensorDesc; - const DML_OPERATOR_DESC transposedQueryDesc {DML_OPERATOR_ELEMENT_WISE_IDENTITY, 
&transposedQueryOperatorDesc}; - - uint32_t reshapedTransposedKeyTensorShape[4] = {batchSize, numHeads, headSize, sequenceLength}; - uint32_t reshapedTransposedKeyTensorStrides[4] = {sequenceLength * numHeads * headSize, headSize, 1, numHeads * headSize}; - TensorDesc reshapedTransposedKeyTensorDesc = TensorDesc( - GetDmlDataTypeFromMlDataType(dataType), - reshapedTransposedKeyTensorShape, - reshapedTransposedKeyTensorStrides); - DML_TENSOR_DESC namedReshapedTransposedKeyTensorDesc = reshapedTransposedKeyTensorDesc.GetDmlDesc(); - - TensorDesc reshapedTransposedKeyOutputTensorDesc = TensorDesc( - GetDmlDataTypeFromMlDataType(dataType), - reshapedTransposedKeyTensorShape); - DML_TENSOR_DESC namedReshapedTransposedKeyOutputTensorDesc = reshapedTransposedKeyOutputTensorDesc.GetDmlDesc(); - - DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposedKeyOperatorDesc{}; - transposedKeyOperatorDesc.InputTensor = &namedReshapedTransposedKeyTensorDesc; - transposedKeyOperatorDesc.OutputTensor = &namedReshapedTransposedKeyOutputTensorDesc; - const DML_OPERATOR_DESC transposedKeyDesc {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposedKeyOperatorDesc}; - - uint32_t queryKeyTensorShape[4] = {batchSize, numHeads, sequenceLength, sequenceLength}; - TensorDesc queryKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeyTensorShape); - DML_TENSOR_DESC namedQueryKeyTensorDesc = queryKeyTensorDesc.GetDmlDesc(); - - float alpha = static_cast(1 / sqrt(headSize)); - DML_GEMM_OPERATOR_DESC attentionScoreOperatorDesc = {}; - attentionScoreOperatorDesc.ATensor = &namedReshapedTransposedQueryOutputTensorDesc; - attentionScoreOperatorDesc.BTensor = &namedReshapedTransposedKeyOutputTensorDesc; - attentionScoreOperatorDesc.CTensor = &namedCastedMaskIndexTensorDesc; - attentionScoreOperatorDesc.OutputTensor = &namedQueryKeyTensorDesc; - attentionScoreOperatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE; - attentionScoreOperatorDesc.TransB = DML_MATRIX_TRANSFORM_NONE; - 
attentionScoreOperatorDesc.Alpha = alpha; - attentionScoreOperatorDesc.Beta = 0.0f; - attentionScoreOperatorDesc.FusedActivation = nullptr; - const DML_OPERATOR_DESC attentionScoreDesc {DML_OPERATOR_GEMM, &attentionScoreOperatorDesc}; - - std::array axes = {3}; - DML_ACTIVATION_SOFTMAX1_OPERATOR_DESC softmaxOperatorDesc = {}; - softmaxOperatorDesc.InputTensor = &namedQueryKeyTensorDesc; - softmaxOperatorDesc.OutputTensor = &namedQueryKeyTensorDesc; - softmaxOperatorDesc.AxisCount = gsl::narrow_cast(axes.size()); - softmaxOperatorDesc.Axes = axes.data(); - const DML_OPERATOR_DESC softmaxDesc = {DML_OPERATOR_ACTIVATION_SOFTMAX1, &softmaxOperatorDesc}; - - uint32_t reshapedTransposedOutputTensorShape[4] {batchSize, numHeads, sequenceLength, headSize}; - uint32_t reshapedTransposedOutputTensorStrides[4] {sequenceLength * numHeads * headSize, headSize * sequenceLength, headSize, 1}; - TensorDesc reshapedTransposedOutputTensorDesc = TensorDesc( - GetDmlDataTypeFromMlDataType(dataType), - reshapedTransposedOutputTensorShape, - reshapedTransposedOutputTensorStrides, - 0 // guaranteedBaseOffsetAlignment - ); - DML_TENSOR_DESC namedReshapedTransposedOutputTensorDesc = reshapedTransposedOutputTensorDesc.GetDmlDesc(); - - DML_GEMM_OPERATOR_DESC attentionWeightOperatorDesc = {}; - attentionWeightOperatorDesc.ATensor = &namedQueryKeyTensorDesc; - attentionWeightOperatorDesc.BTensor = &namedReshapedTransposedQueryOutputTensorDesc; - attentionWeightOperatorDesc.CTensor = nullptr; - attentionWeightOperatorDesc.OutputTensor = &namedReshapedTransposedOutputTensorDesc; - attentionWeightOperatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE; - attentionWeightOperatorDesc.TransB = DML_MATRIX_TRANSFORM_NONE; - attentionWeightOperatorDesc.Alpha = 1.0f; - attentionWeightOperatorDesc.Beta = 0.0f; - attentionWeightOperatorDesc.FusedActivation = nullptr; - const DML_OPERATOR_DESC attentionWeightDesc {DML_OPERATOR_GEMM, &attentionWeightOperatorDesc}; - - TensorDesc transposedOutputTensorDesc = 
TensorDesc( - m_outputTensorDescs[0].GetDmlDataType(), - m_outputTensorDescs[0].GetSizes(), - std::nullopt, - 0 // guaranteedBaseOffsetAlignment - ); - DML_TENSOR_DESC namedTransposedOutputTensorDesc = transposedOutputTensorDesc.GetDmlDesc(); - - DML_ACTIVATION_LINEAR_OPERATOR_DESC outputOperatorDesc = {}; - outputOperatorDesc.Alpha = 1.0f; - outputOperatorDesc.Beta = 0.0f; - outputOperatorDesc.InputTensor = &outputDescs[0]; - outputOperatorDesc.OutputTensor = &namedTransposedOutputTensorDesc; - const DML_OPERATOR_DESC outputDesc {DML_OPERATOR_ACTIVATION_LINEAR, &outputOperatorDesc}; - - enum NodeIndex : uint32_t + DML_SLICE1_OPERATOR_DESC maskSlicedOperatorDesc = {}; + if (hasMaxSequenceMask) { - xWeight, - querySlice, - keySlice, - valueSlice, - queryTranspose, - keyTranspose, - attentionScore, - softmax, - valueTranspose, - attentionWeight, - castMaskIndex, - mask, - output, - count, - }; + maskSliceOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(maskTensorDataType, maskSliceOutputShape); + namedMaskSliceOutputTensorDesc = maskSliceOutputTensorDesc.GetDmlDesc(); + maskSlicedOperatorDesc.InputTensor = &inputDescs[dmlMaskIndex]; + maskSlicedOperatorDesc.OutputTensor = &namedMaskSliceOutputTensorDesc; + maskSlicedOperatorDesc.DimensionCount = gsl::narrow_cast(maskSliceOutputShape.size()); + maskSlicedOperatorDesc.InputWindowOffsets = maskSliceOffsets.data(); + maskSlicedOperatorDesc.InputWindowSizes = maskSliceOutputShape.data(); + maskSlicedOperatorDesc.InputWindowStrides = maskSliceStrides.data(); + } + const DML_OPERATOR_DESC maskSlicedDesc = { DML_OPERATOR_SLICE1, &maskSlicedOperatorDesc}; - MLOperatorGraphDesc operatorGraphDesc = {}; - std::array opDescs = { - &xWeightDesc, - &querySlicedDesc, - &keySlicedDesc, - &valueSlicedDesc, - &transposedQueryDesc, - &transposedKeyDesc, - &attentionScoreDesc, - &softmaxDesc, - &transposedQueryDesc, - &attentionWeightDesc, - &castMaskIndexDesc, - &maskDesc, - &outputDesc - }; - operatorGraphDesc.nodeCount = 
NodeIndex::count; - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {}; + mhaOperatorDesc.ValueTensor = hasSlicedValue ? &namedValueSlicedInputTensorDesc : nullptr; + mhaOperatorDesc.StackedQueryKeyTensor = hasSlicedValue ? &namedQueryKeyTransposedOutputTensorDesc : nullptr; + mhaOperatorDesc.StackedQueryKeyValueTensor = hasSlicedValue ? nullptr : &namedQueryKeyValueTransposedOutputTensorDesc; - // set input edges - std::pair nodeToNodeInputIndex[4] { - {NodeIndex::xWeight, 0}, - {NodeIndex::xWeight, 1}, - {NodeIndex::xWeight, 2}, - {NodeIndex::castMaskIndex, 0} - }; - std::array inputEdges; - for (uint32_t inputIndex = 0; inputIndex < inputEdges.size(); inputIndex++) + if (hasMaxSequenceMask) { - DML_INPUT_GRAPH_EDGE_DESC inputEdge = {}; - inputEdge.GraphInputIndex = inputIndex; - inputEdge.ToNodeIndex = nodeToNodeInputIndex[inputIndex].first; - inputEdge.ToNodeInputIndex = nodeToNodeInputIndex[inputIndex].second; - inputEdges[inputIndex] = inputEdge; + mhaOperatorDesc.MaskTensor = &namedMaskSliceOutputTensorDesc; } - operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); - operatorGraphDesc.inputEdges = inputEdges.data(); + else + { + mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[dmlMaskIndex] : nullptr; + } + + mhaOperatorDesc.RelativePositionBiasTensor = hasRelativePositionBias ? 
&inputDescs[dmlRelativePositionBiasIndex] : nullptr; + mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; + mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); + mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); + mhaOperatorDesc.HeadCount = numHeads; + mhaOperatorDesc.MaskType = maskType; + const DML_OPERATOR_DESC mhaDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION, &mhaOperatorDesc }; - // set intermediate edges + // Construct the graph + std::vector inputEdges; std::vector intermediateEdges; + std::vector outputEdges; + + std::vector opDescs = { + &gemmDesc, + &mhaDesc, + }; + + uint32_t currentNodeIndex = 0; + const uint32_t gemmNodeIndex = currentNodeIndex++; + const uint32_t mhaNodeIndex = currentNodeIndex++; + + uint32_t valueSliceNodeIndex = 0; + uint32_t queryKeySliceNodeIndex = 0; + uint32_t queryKeyTransposedNodeIndex = 0; + uint32_t queryKeyValueTransposedNodeIndex = 0; + if (hasSlicedValue) + { + opDescs.push_back(&queryKeySlicedDesc); + queryKeySliceNodeIndex = currentNodeIndex++; + + opDescs.push_back(&valueSlicedDesc); + valueSliceNodeIndex = currentNodeIndex++; + + opDescs.push_back(&transposedDesc); + queryKeyTransposedNodeIndex = currentNodeIndex++; + } + else + { + opDescs.push_back(&transposedDesc); + queryKeyValueTransposedNodeIndex = currentNodeIndex++; + } + + uint32_t maskSliceNodeIndex = 0; + if (hasMaxSequenceMask) + { + opDescs.push_back(&maskSlicedDesc); + maskSliceNodeIndex = currentNodeIndex++; + } + + DML_INPUT_GRAPH_EDGE_DESC inputToGemmEdge = {}; + inputToGemmEdge.GraphInputIndex = dmlInputIndex; + inputToGemmEdge.ToNodeIndex = gemmNodeIndex; + inputToGemmEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToGemmEdge); + + DML_INPUT_GRAPH_EDGE_DESC weightToGemmEdge = {}; + weightToGemmEdge.GraphInputIndex = dmlWeightsIndex; + weightToGemmEdge.ToNodeIndex = gemmNodeIndex; + 
weightToGemmEdge.ToNodeInputIndex = 1; + inputEdges.push_back(weightToGemmEdge); + + if (hasBias) + { + DML_INPUT_GRAPH_EDGE_DESC biasToGemmEdge = {}; + biasToGemmEdge.GraphInputIndex = dmlBiasIndex; + biasToGemmEdge.ToNodeIndex = gemmNodeIndex; + biasToGemmEdge.ToNodeInputIndex = 2; + inputEdges.push_back(biasToGemmEdge); + } + + if (hasMask) + { + if (hasUnpaddedBounds) + { + DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {}; + maskToMhaEdge.GraphInputIndex = dmlMaskIndex; + maskToMhaEdge.ToNodeIndex = mhaNodeIndex; + maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex; + inputEdges.push_back(maskToMhaEdge); + } + else if (hasMaxSequenceMask) + { + DML_INPUT_GRAPH_EDGE_DESC maskToMaskSliceEdge = {}; + maskToMaskSliceEdge.GraphInputIndex = dmlMaskIndex; + maskToMaskSliceEdge.ToNodeIndex = maskSliceNodeIndex; + maskToMaskSliceEdge.ToNodeInputIndex = 0; + inputEdges.push_back(maskToMaskSliceEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC maskSliceToMhaEdge = {}; + maskSliceToMhaEdge.FromNodeIndex = maskSliceNodeIndex; + maskSliceToMhaEdge.FromNodeOutputIndex = 0; + maskSliceToMhaEdge.ToNodeIndex = mhaNodeIndex; + maskSliceToMhaEdge.ToNodeInputIndex = mhaMaskIndex; + intermediateEdges.push_back(maskSliceToMhaEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {}; + maskToMhaEdge.GraphInputIndex = dmlMaskIndex; + maskToMhaEdge.ToNodeIndex = mhaNodeIndex; + maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex; + inputEdges.push_back(maskToMhaEdge); + } + } + + if (hasRelativePositionBias) + { + DML_INPUT_GRAPH_EDGE_DESC relativePositionBiasToMhaEdge = {}; + relativePositionBiasToMhaEdge.GraphInputIndex = dmlRelativePositionBiasIndex; + relativePositionBiasToMhaEdge.ToNodeIndex = mhaNodeIndex; + relativePositionBiasToMhaEdge.ToNodeInputIndex = mhaRelativePositionBiasIndex; + inputEdges.push_back(relativePositionBiasToMhaEdge); + } + + if (hasSlicedValue) + { + // We need to slice QK and V, and transpose QK + DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToQueryKeySliceEdge = {}; + 
gemmToQueryKeySliceEdge.FromNodeIndex = gemmNodeIndex; + gemmToQueryKeySliceEdge.FromNodeOutputIndex = 0; + gemmToQueryKeySliceEdge.ToNodeIndex = queryKeySliceNodeIndex; + gemmToQueryKeySliceEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(gemmToQueryKeySliceEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeySliceToTransposeEdge = {}; + queryKeySliceToTransposeEdge.FromNodeIndex = queryKeySliceNodeIndex; + queryKeySliceToTransposeEdge.FromNodeOutputIndex = 0; + queryKeySliceToTransposeEdge.ToNodeIndex = queryKeyTransposedNodeIndex; + queryKeySliceToTransposeEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(queryKeySliceToTransposeEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyTransposedToMhaEdge = {}; + queryKeyTransposedToMhaEdge.FromNodeIndex = queryKeyTransposedNodeIndex; + queryKeyTransposedToMhaEdge.FromNodeOutputIndex = 0; + queryKeyTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex; + queryKeyTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyIndex; + intermediateEdges.push_back(queryKeyTransposedToMhaEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToValueSliceEdge = {}; + gemmToValueSliceEdge.FromNodeIndex = gemmNodeIndex; + gemmToValueSliceEdge.FromNodeOutputIndex = 0; + gemmToValueSliceEdge.ToNodeIndex = valueSliceNodeIndex; + gemmToValueSliceEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(gemmToValueSliceEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC valueSliceToMhaEdge = {}; + valueSliceToMhaEdge.FromNodeIndex = valueSliceNodeIndex; + valueSliceToMhaEdge.FromNodeOutputIndex = 0; + valueSliceToMhaEdge.ToNodeIndex = mhaNodeIndex; + valueSliceToMhaEdge.ToNodeInputIndex = mhaValueIndex; + intermediateEdges.push_back(valueSliceToMhaEdge); + } + else + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToQueryKeyValueTransposeEdge = {}; + gemmToQueryKeyValueTransposeEdge.FromNodeIndex = gemmNodeIndex; + gemmToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0; + gemmToQueryKeyValueTransposeEdge.ToNodeIndex = 
queryKeyValueTransposedNodeIndex; + gemmToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(gemmToQueryKeyValueTransposeEdge); + + // All we need to do here is transpose the stacked QKV tensor into something DML supports + DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {}; + queryKeyValueTransposedToMhaEdge.FromNodeIndex = queryKeyValueTransposedNodeIndex; + queryKeyValueTransposedToMhaEdge.FromNodeOutputIndex = 0; + queryKeyValueTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex; + queryKeyValueTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyValueIndex; + intermediateEdges.push_back(queryKeyValueTransposedToMhaEdge); + } - DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToQuerySliceEdge = {}; - gemmToQuerySliceEdge.FromNodeIndex = NodeIndex::xWeight; - gemmToQuerySliceEdge.FromNodeOutputIndex = 0; - gemmToQuerySliceEdge.ToNodeIndex = NodeIndex::querySlice; - gemmToQuerySliceEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(gemmToQuerySliceEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToKeySliceEdge = {}; - gemmToKeySliceEdge.FromNodeIndex = NodeIndex::xWeight; - gemmToKeySliceEdge.FromNodeOutputIndex = 0; - gemmToKeySliceEdge.ToNodeIndex = NodeIndex::keySlice; - gemmToKeySliceEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(gemmToKeySliceEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToValueSliceEdge = {}; - gemmToValueSliceEdge.FromNodeIndex = NodeIndex::xWeight; - gemmToValueSliceEdge.FromNodeOutputIndex = 0; - gemmToValueSliceEdge.ToNodeIndex = NodeIndex::valueSlice; - gemmToValueSliceEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(gemmToValueSliceEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC querySliceToQueryTranspose = {}; - querySliceToQueryTranspose.FromNodeIndex = NodeIndex::querySlice; - querySliceToQueryTranspose.FromNodeOutputIndex = 0; - querySliceToQueryTranspose.ToNodeIndex = NodeIndex::queryTranspose; - querySliceToQueryTranspose.ToNodeInputIndex = 0; - 
intermediateEdges.push_back(querySliceToQueryTranspose); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC keySliceToKeyTranspose = {}; - keySliceToKeyTranspose.FromNodeIndex = NodeIndex::keySlice; - keySliceToKeyTranspose.FromNodeOutputIndex = 0; - keySliceToKeyTranspose.ToNodeIndex = NodeIndex::keyTranspose; - keySliceToKeyTranspose.ToNodeInputIndex = 0; - intermediateEdges.push_back(keySliceToKeyTranspose); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC queryTransposeToGemm = {}; - queryTransposeToGemm.FromNodeIndex = NodeIndex::queryTranspose; - queryTransposeToGemm.FromNodeOutputIndex = 0; - queryTransposeToGemm.ToNodeIndex = NodeIndex::attentionScore; - queryTransposeToGemm.ToNodeInputIndex = 0; - intermediateEdges.push_back(queryTransposeToGemm); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC keyTransposeToGemm = {}; - keyTransposeToGemm.FromNodeIndex = NodeIndex::keyTranspose; - keyTransposeToGemm.FromNodeOutputIndex = 0; - keyTransposeToGemm.ToNodeIndex = NodeIndex::attentionScore; - keyTransposeToGemm.ToNodeInputIndex = 1; - intermediateEdges.push_back(keyTransposeToGemm); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC castedMaskIndexToIdentity = {}; - castedMaskIndexToIdentity.FromNodeIndex = NodeIndex::castMaskIndex; - castedMaskIndexToIdentity.FromNodeOutputIndex = 0; - castedMaskIndexToIdentity.ToNodeIndex = NodeIndex::mask; - castedMaskIndexToIdentity.ToNodeInputIndex = 0; - intermediateEdges.push_back(castedMaskIndexToIdentity); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC maskToGemm = {}; - maskToGemm.FromNodeIndex = NodeIndex::mask; - maskToGemm.FromNodeOutputIndex = 0; - maskToGemm.ToNodeIndex = NodeIndex::attentionScore; - maskToGemm.ToNodeInputIndex = 2; - intermediateEdges.push_back(maskToGemm); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC attentionScoreToSoftmax = {}; - attentionScoreToSoftmax.FromNodeIndex = NodeIndex::attentionScore; - attentionScoreToSoftmax.FromNodeOutputIndex = 0; - attentionScoreToSoftmax.ToNodeIndex = NodeIndex::softmax; - attentionScoreToSoftmax.ToNodeInputIndex = 0; - 
intermediateEdges.push_back(attentionScoreToSoftmax); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC softmaxToGemm = {}; - softmaxToGemm.FromNodeIndex = NodeIndex::softmax; - softmaxToGemm.FromNodeOutputIndex = 0; - softmaxToGemm.ToNodeIndex = NodeIndex::attentionWeight; - softmaxToGemm.ToNodeInputIndex = 0; - intermediateEdges.push_back(softmaxToGemm); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC valueSliceToValueTranspose = {}; - valueSliceToValueTranspose.FromNodeIndex = NodeIndex::valueSlice; - valueSliceToValueTranspose.FromNodeOutputIndex = 0; - valueSliceToValueTranspose.ToNodeIndex = NodeIndex::valueTranspose; - valueSliceToValueTranspose.ToNodeInputIndex = 0; - intermediateEdges.push_back(valueSliceToValueTranspose); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC valueTransposeToGemm = {}; - valueTransposeToGemm.FromNodeIndex = NodeIndex::valueTranspose; - valueTransposeToGemm.FromNodeOutputIndex = 0; - valueTransposeToGemm.ToNodeIndex = NodeIndex::attentionWeight; - valueTransposeToGemm.ToNodeInputIndex = 1; - intermediateEdges.push_back(valueTransposeToGemm); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToIdentity = {}; - gemmToIdentity.FromNodeIndex = NodeIndex::attentionWeight; - gemmToIdentity.FromNodeOutputIndex = 0; - gemmToIdentity.ToNodeIndex = NodeIndex::output; - gemmToIdentity.ToNodeInputIndex = 0; - intermediateEdges.push_back(gemmToIdentity); + DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {}; + mhaToOutputEdge.FromNodeIndex = mhaNodeIndex; + mhaToOutputEdge.FromNodeOutputIndex = 0; + mhaToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(mhaToOutputEdge); + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); operatorGraphDesc.intermediateEdges = intermediateEdges.data(); - - // set the output edges - std::array outputEdges; - DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = 
{}; - outputEdge.FromNodeIndex = NodeIndex::output; - outputEdge.FromNodeOutputIndex = 0; - outputEdge.GraphOutputIndex = 0; - outputEdges[0] = outputEdge; operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } @@ -448,32 +540,37 @@ class DmlOperatorAttention : public DmlOperator void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported) { *isSupported = false; - // Fall back to CPU if input 'past' and 'relative_position_bias' is present because there is no current use case for this. - // and it will make the implementation more complex. - // Also fall back to CPU if output 'present' is present for same reason as above. - if (context->GetInputCount() > 4 || context->GetOutputCount() > 1) + // `past` input tensor is not supported yet + if (context->IsInputValid(4)) { return; } - // Checking input count alone is not sufficient to fallback to CPU if input 'past' and 'relative_position_bias' is present - // because input 'mask_index', 'past', and 'relative_position_bias' all are optional. - if (context->IsInputValid(4) || context->IsInputValid(5)) + + // `past_sequence_length` input tensor is not supported yet + if (context->IsInputValid(6)) { return; } - // Fall back to CPU if attibute 'qkv_hidden_sizes' is present or - // if value of attribute 'unidirectional' is 1, because of same reason as above. 
- MLOperatorAttributes attributes(context); - if (attributes.HasAttribute(AttrName::QkvHiddenSizes, MLOperatorAttributeType::IntArray)) + + // `present` output tensor is not supported yet + if (context->IsOutputValid(1)) { return; } + // `unidirectional == 1` is not supported yet + MLOperatorAttributes attributes(context); if (attributes.GetOptionalAttribute(AttrName::Unidirectional, 0) != 0) { return; } + // `do_rotary == 1` is not supported yet + if (attributes.GetOptionalAttribute(AttrName::DoRotary, 0) != 0) + { + return; + } + *isSupported = true; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp new file mode 100644 index 0000000000..9c1a7baeaa --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp @@ -0,0 +1,281 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "precomp.h" + +namespace Dml +{ +class DmlOperatorMultiHeadAttention : public DmlOperator +{ +public: + DmlOperatorMultiHeadAttention(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext) + { + enum InputIndex : uint32_t + { + queryIndex, + keyIndex, + valueIndex, + biasIndex, + maskIndex, + relativePositionBiasIndex, + pastKeyIndex, + pastValueIndex, + inputCount, + }; + + enum DmlInputIndex : uint32_t + { + dmlQueryIndex, + dmlKeyIndex, + dmlValueIndex, + dmlStackedQueryKeyIndex, + dmlStackedKeyValueIndex, + dmlStackedQueryKeyValueIndex, + dmlBiasIndex, + dmlMaskIndex, + dmlRelativePositionBiasIndex, + dmlPastKeyIndex, + dmlPastValueIndex, + dmlInputCount, + }; + + enum OutputIndex : uint32_t + { + outputIndex, + outputPresentKeyIndex, + outputPresentValueIndex, + outputCount, + }; + + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 1); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1); + + const bool keyValueIsPast = kernelCreationContext.IsInputValid(keyIndex) && kernelCreationContext.GetInputTensorDimensionCount(keyIndex) == 4; + const bool hasValue = kernelCreationContext.IsInputValid(valueIndex) && !keyValueIsPast; + const bool hasBias = kernelCreationContext.IsInputValid(biasIndex); + const bool hasMask = kernelCreationContext.IsInputValid(maskIndex); + const bool hasRelativePositionBias = kernelCreationContext.IsInputValid(relativePositionBiasIndex); + const bool hasPastKey = keyValueIsPast || kernelCreationContext.IsInputValid(pastKeyIndex); + const bool hasPastValue = keyValueIsPast || kernelCreationContext.IsInputValid(pastValueIndex); + const bool hasPresentKeyOutput = kernelCreationContext.IsOutputValid(outputPresentKeyIndex); + const bool hasPresentValueOutput = kernelCreationContext.IsOutputValid(outputPresentValueIndex); + const bool stackedQkv = kernelCreationContext.GetInputTensorDimensionCount(queryIndex) == 5; + const bool stackedKv = 
kernelCreationContext.IsInputValid(keyIndex) && kernelCreationContext.GetInputTensorDimensionCount(keyIndex) == 5; + const bool hasKey = !stackedKv && !keyValueIsPast && kernelCreationContext.IsInputValid(keyIndex); + + std::vector> inputIndices = { + stackedQkv ? std::nullopt : std::optional(queryIndex), + hasKey ? std::optional(keyIndex) : std::nullopt, + hasValue ? std::optional(valueIndex) : std::nullopt, + std::nullopt, + stackedKv ? std::optional(keyIndex) : std::nullopt, + stackedQkv ? std::optional(queryIndex) : std::nullopt, + biasIndex, + hasMask ? std::optional(maskIndex) : std::nullopt, + relativePositionBiasIndex, + keyValueIsPast ? keyIndex : pastKeyIndex, + keyValueIsPast ? valueIndex : pastValueIndex, + }; + + std::vector> outputIndices = { + outputIndex, + outputPresentKeyIndex, + outputPresentValueIndex, + }; + DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices, std::nullopt, std::nullopt, 1); + + ML_CHECK_VALID_ARGUMENT(!stackedQkv || m_inputTensorDescs[dmlStackedQueryKeyValueIndex].GetDimensionCount() == 5); + ML_CHECK_VALID_ARGUMENT(stackedQkv || m_inputTensorDescs[dmlQueryIndex].GetDimensionCount() == 3); + ML_CHECK_VALID_ARGUMENT(!hasKey || m_inputTensorDescs[dmlKeyIndex].GetDimensionCount() == 3); + ML_CHECK_VALID_ARGUMENT(!hasValue || m_inputTensorDescs[dmlValueIndex].GetDimensionCount() == 3); + ML_CHECK_VALID_ARGUMENT(!hasPastKey || m_inputTensorDescs[dmlPastKeyIndex].GetDimensionCount() == 4); + ML_CHECK_VALID_ARGUMENT(!hasPastValue || m_inputTensorDescs[dmlPastValueIndex].GetDimensionCount() == 4); + + const uint32_t batchSize = stackedQkv + ? m_inputTensorDescs[dmlStackedQueryKeyValueIndex].GetSizes()[0] + : m_inputTensorDescs[dmlQueryIndex].GetSizes()[0]; + + const uint32_t numHeads = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::NumHeads)); + const uint32_t headSize = stackedQkv + ? 
m_inputTensorDescs[dmlStackedQueryKeyValueIndex].GetSizes()[4] + : m_inputTensorDescs[dmlQueryIndex].GetSizes()[2] / numHeads; + + const uint32_t sequenceLength = stackedQkv + ? m_inputTensorDescs[dmlStackedQueryKeyValueIndex].GetSizes()[1] + : m_inputTensorDescs[dmlQueryIndex].GetSizes()[1]; + + uint32_t kvSequenceLength; + if (hasKey) + { + kvSequenceLength = m_inputTensorDescs[dmlKeyIndex].GetSizes()[1]; + } + else if (stackedKv) + { + kvSequenceLength = m_inputTensorDescs[dmlStackedKeyValueIndex].GetSizes()[1]; + } + else if (hasPastKey) + { + kvSequenceLength = m_inputTensorDescs[dmlPastKeyIndex].GetSizes()[2]; + } + else + { + kvSequenceLength = sequenceLength; + } + + const uint32_t hiddenSize = numHeads * headSize; + const uint32_t vHiddenSize = hasValue ? m_inputTensorDescs[dmlValueIndex].GetSizes()[2] : hiddenSize; + const uint32_t pastSequenceLength = hasPastKey ? m_inputTensorDescs[dmlPastKeyIndex].GetSizes()[2] : 0; + const uint32_t totalSequenceLength = kvSequenceLength + pastSequenceLength; + + if (stackedQkv) + { + auto stackedQkvSizes = m_inputTensorDescs[dmlStackedQueryKeyValueIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(stackedQkvSizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(stackedQkvSizes[1] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(stackedQkvSizes[2] == numHeads); + ML_CHECK_VALID_ARGUMENT(stackedQkvSizes[3] == 3); + ML_CHECK_VALID_ARGUMENT(stackedQkvSizes[4] == headSize); + } + else + { + auto querySizes = m_inputTensorDescs[dmlQueryIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(querySizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(querySizes[1] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(querySizes[2] == hiddenSize); + } + + if (hasKey) + { + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlKeyIndex].GetDimensionCount() == 3); + + auto keySizes = m_inputTensorDescs[dmlKeyIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(keySizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(keySizes[1] == kvSequenceLength); + 
ML_CHECK_VALID_ARGUMENT(keySizes[2] == hiddenSize); + } + + if (hasValue) + { + auto valueSizes = m_inputTensorDescs[dmlValueIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(valueSizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(valueSizes[1] == kvSequenceLength); + ML_CHECK_VALID_ARGUMENT(valueSizes[2] == vHiddenSize); + } + + if (stackedKv) + { + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlStackedKeyValueIndex].GetDimensionCount() == 5); + + auto stackedKvSizes = m_inputTensorDescs[dmlStackedKeyValueIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(stackedKvSizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(stackedKvSizes[1] == kvSequenceLength); + ML_CHECK_VALID_ARGUMENT(stackedKvSizes[2] == numHeads); + ML_CHECK_VALID_ARGUMENT(stackedKvSizes[3] == 2); + ML_CHECK_VALID_ARGUMENT(stackedKvSizes[4] == headSize); + } + + if (hasBias) + { + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlBiasIndex].GetDimensionCount() == 1); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlBiasIndex].GetSizes()[0] == hiddenSize + hiddenSize + vHiddenSize); + } + + DML_MULTIHEAD_ATTENTION_MASK_TYPE maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE; + if (hasMask) + { + if (kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1) + { + const auto unpaddedKeyBoundsShape = m_inputTensorDescs[dmlMaskIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(unpaddedKeyBoundsShape.size() == 1); + ML_CHECK_VALID_ARGUMENT(unpaddedKeyBoundsShape[0] == batchSize || unpaddedKeyBoundsShape[0] == batchSize * 3 + 2); + + maskType = unpaddedKeyBoundsShape[0] == batchSize + ? 
DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH + : DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END; + + if (maskType == DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH) + { + uint32_t desiredShape[2] = {1, batchSize}; + m_inputTensorDescs[dmlMaskIndex] = TensorDesc( + m_inputTensorDescs[dmlMaskIndex].GetDmlDataType(), + desiredShape); + } + } + else + { + const auto keyPaddingMaskTensorShape = m_inputTensorDescs[dmlMaskIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape.size() == 2); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength); + + const uint32_t actualShape[4] = {batchSize, 1, 1, kvSequenceLength}; + const uint32_t desiredShape[4] = {batchSize, numHeads, sequenceLength, kvSequenceLength}; + + m_inputTensorDescs[dmlMaskIndex] = TensorDesc::ConstructBroadcastedTensorDesc( + m_inputTensorDescs[dmlMaskIndex].GetMlOperatorDataType(), + desiredShape, + actualShape); + + maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN; + } + } + + if (hasRelativePositionBias) + { + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlRelativePositionBiasIndex].GetDimensionCount() == 4); + + auto relativePositionBiasSizes = m_inputTensorDescs[dmlRelativePositionBiasIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(relativePositionBiasSizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(relativePositionBiasSizes[1] == numHeads); + ML_CHECK_VALID_ARGUMENT(relativePositionBiasSizes[2] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(relativePositionBiasSizes[3] == totalSequenceLength); + } + + if (hasPastKey) + { + auto pastKeySizes = m_inputTensorDescs[dmlPastKeyIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(pastKeySizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(pastKeySizes[1] == numHeads); + ML_CHECK_VALID_ARGUMENT(pastKeySizes[2] == pastSequenceLength); + ML_CHECK_VALID_ARGUMENT(pastKeySizes[3] == headSize); + } + + if (hasPastValue) + { + auto 
pastValueSizes = m_inputTensorDescs[dmlPastValueIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(pastValueSizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(pastValueSizes[1] == numHeads); + ML_CHECK_VALID_ARGUMENT(pastValueSizes[2] == pastSequenceLength); + ML_CHECK_VALID_ARGUMENT(pastValueSizes[3] == headSize); + } + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaDesc = {}; + mhaDesc.QueryTensor = stackedQkv ? nullptr : &inputDescs[dmlQueryIndex]; + mhaDesc.KeyTensor = hasKey ? &inputDescs[dmlKeyIndex] : nullptr; + mhaDesc.ValueTensor = hasValue ? &inputDescs[dmlValueIndex] : nullptr; + mhaDesc.StackedKeyValueTensor = stackedKv ? &inputDescs[dmlStackedKeyValueIndex] : nullptr; + mhaDesc.StackedQueryKeyValueTensor = stackedQkv ? &inputDescs[dmlStackedQueryKeyValueIndex] : nullptr; + mhaDesc.BiasTensor = hasBias ? &inputDescs[dmlBiasIndex] : nullptr; + mhaDesc.MaskTensor = hasMask ? &inputDescs[dmlMaskIndex] : nullptr; + mhaDesc.RelativePositionBiasTensor = hasRelativePositionBias ? &inputDescs[dmlRelativePositionBiasIndex] : nullptr; + mhaDesc.PastKeyTensor = hasPastKey ? &inputDescs[dmlPastKeyIndex] : nullptr; + mhaDesc.PastValueTensor = hasPastValue ? &inputDescs[dmlPastValueIndex] : nullptr; + mhaDesc.OutputTensor = &outputDescs[outputIndex]; + mhaDesc.OutputPresentKeyTensor = hasPresentKeyOutput ? &outputDescs[outputPresentKeyIndex] : nullptr; + mhaDesc.OutputPresentValueTensor = hasPresentValueOutput ? 
&outputDescs[outputPresentValueIndex] : nullptr; + mhaDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); + mhaDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); + mhaDesc.HeadCount = numHeads; + mhaDesc.MaskType = maskType; + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION, &mhaDesc }; + SetDmlOperatorDesc(opDesc, kernelCreationContext); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(MultiHeadAttention, DmlOperatorMultiHeadAttention); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp index c3a25ca8d4..892efca305 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRoiAlign.cpp @@ -29,14 +29,25 @@ class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelp {"max", DML_REDUCE_FUNCTION_MAX}, {"avg", DML_REDUCE_FUNCTION_AVERAGE}, }; + + constexpr NameAndIndex coordinateTransformationModes[] = + { + {"half_pixel", 0}, + {"output_half_pixel", 1}, + }; + + std::string coordinateTransformationMode = kernelCreationContext.GetOptionalAttribute(AttrName::CoordinateTransformationMode, "half_pixel"); + auto optionalCoordinateTransformationModeValue = TryMapStringToIndex(coordinateTransformationMode, coordinateTransformationModes); const std::string mode = kernelCreationContext.GetOptionalAttribute(AttrName::Mode, "avg"); const auto optionalReductionFunction = TryMapStringToIndex(mode, mapping); const float spatialScale = kernelCreationContext.GetOptionalAttribute(AttrName::SpatialScale, 1.0f); const int32_t samplesPerOutput = kernelCreationContext.GetOptionalAttribute(AttrName::SamplingRatio, 0u); 
ML_CHECK_VALID_ARGUMENT(samplesPerOutput >= 0, "sampling_ratio must be 0 or positive."); ML_CHECK_VALID_ARGUMENT(!!optionalReductionFunction, "Unsupported RoiAlign mode."); + ML_CHECK_VALID_ARGUMENT(!!optionalCoordinateTransformationModeValue, "Unsupported RoiAlign coordinate_transformation_mode."); + - DML_ROI_ALIGN_OPERATOR_DESC operatorDesc = {}; + DML_ROI_ALIGN1_OPERATOR_DESC operatorDesc = {}; operatorDesc.InputTensor = &inputDescs[0]; operatorDesc.ROITensor = &inputDescs[1]; operatorDesc.BatchIndicesTensor = &inputDescs[2]; @@ -48,12 +59,15 @@ class DmlOperatorRegionOfInterestAlign : public DmlOperator, public RoiAlignHelp operatorDesc.MaximumSamplesPerOutput = (samplesPerOutput == 0) ? UINT32_MAX : samplesPerOutput; operatorDesc.ReductionFunction = *optionalReductionFunction; operatorDesc.InterpolationMode = DML_INTERPOLATION_MODE_LINEAR; - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN, &operatorDesc }; + operatorDesc.InputPixelOffset = (*optionalCoordinateTransformationModeValue == 0)? 
0.5f : 0.0f; + operatorDesc.OutputPixelOffset = -0.5f; + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ROI_ALIGN1, &operatorDesc }; SetDmlOperatorDesc(opDesc, kernelCreationContext); } }; DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign10, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(RoiAlign16, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/GenerateShaders.bat b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/GenerateShaders.bat index 64a0ff5737..fb087bd800 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/GenerateShaders.bat +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/GenerateShaders.bat @@ -7,10 +7,78 @@ if "%1" == "DEBUG" ( fxc.exe ..\Shaders\bluestein_chirp.hlsl -E BluesteinZChirp -T cs_5_0 /DTBUFFER=float /Zi /Od /Fh bluestein_chirp.h dxc.exe ..\Shaders\bluestein_chirp.hlsl -E BluesteinZChirp -T cs_6_2 -DTBUFFER=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh bluestein_chirp_fp16.h + + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_float.h + fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=uint /DTBUFFER2=float /Zi /Od /Fh grid_sample_uint_float.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint64_float.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int16_float.h + fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=int /DTBUFFER2=float /Zi /Od /Fh grid_sample_int_float.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float 
-enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_float.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_float.h + fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=float /DTBUFFER2=float /Zi /Od /Fh grid_sample_float_float.h + fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=double /DTBUFFER2=float /Zi /Od /Fh grid_sample_double_float.h + fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=bool /DTBUFFER2=float /Zi /Od /Fh grid_sample_bool_float.h + + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint64_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int16_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh 
grid_sample_float_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_double_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=float16_t -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_bool_fp16.h + + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint16_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_uint64_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int16_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_int64_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_fp16_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_float_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_double_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool 
-DTBUFFER2=double -enable-16bit-types -Zi -Od -Qembed_debug -Fh grid_sample_bool_double.h + ) else ( fxc.exe ..\Shaders\stockham.hlsl -E DFT -T cs_5_0 /DTBUFFER=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh stockham.h dxc.exe ..\Shaders\stockham.hlsl -E DFT -T cs_6_2 -DTBUFFER=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh stockham_fp16.h fxc.exe ..\Shaders\bluestein_chirp.hlsl -E BluesteinZChirp -T cs_5_0 /DTBUFFER=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh bluestein_chirp.h dxc.exe ..\Shaders\bluestein_chirp.hlsl -E BluesteinZChirp -T cs_6_2 -DTBUFFER=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh bluestein_chirp_fp16.h + + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_float.h + fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=uint /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_uint_float.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint64_float.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int16_float.h + fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=int /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_int_float.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_float.h + 
dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_float.h + fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=float /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_float_float.h + fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=double /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_double_float.h + fxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_5_0 /DTBUFFER1=bool /DTBUFFER2=float /O3 /Qstrip_reflect /Qstrip_debug /Qstrip_rootsignature /Qstrip_priv /Fh grid_sample_bool_float.h + + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint64_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int16_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature 
-Fh grid_sample_int64_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_float_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_double_fp16.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=float16_t -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_bool_fp16.h + + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint16_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=uint64_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_uint64_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int16_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=int64_t -DTBUFFER2=double -enable-16bit-types 
-O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_int64_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float16_t -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_fp16_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=float -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_float_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=double -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_double_double.h + dxc.exe ..\Shaders\grid_sample.hlsl -E GridSample -T cs_6_2 -DTBUFFER1=bool -DTBUFFER2=double -enable-16bit-types -O3 -Qstrip_reflect -Qstrip_debug -Qstrip_rootsignature -Fh grid_sample_bool_double.h + ) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/grid_sample_bool_double.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/grid_sample_bool_double.h new file mode 100644 index 0000000000..6d83865ecb --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/grid_sample_bool_double.h @@ -0,0 +1,6398 @@ +#if 0 +; +; Note: shader requires additional functionality: +; Double-precision floating point +; +; +; Input signature: +; +; Name Index Mask Register SysValue Format Used +; -------------------- ----- ------ -------- -------- ------- ------ +; no parameters +; +; Output signature: +; +; Name Index Mask Register SysValue Format Used +; -------------------- ----- ------ -------- -------- ------- ------ +; no parameters +; shader hash: eff86a5dd3f8ca652b3700c52570e535 +; +; Pipeline Runtime Information: +; +; +; +; Buffer Definitions: +; +; cbuffer +; { +; +; [116 x i8] (type annotation not present) +; +; } +; +; Resource bind info for +; { +; +; [4 
x i8] (type annotation not present) +; +; } +; +; Resource bind info for +; { +; +; [8 x i8] (type annotation not present) +; +; } +; +; Resource bind info for +; { +; +; [4 x i8] (type annotation not present) +; +; } +; +; +; Resource Bindings: +; +; Name Type Format Dim ID HLSL Bind Count +; ------------------------------ ---------- ------- ----------- ------- -------------- ------ +; cbuffer NA NA CB0 cb0 1 +; UAV struct r/w U0 u0 1 +; UAV struct r/w U1 u1 1 +; UAV struct r/w U2 u2 1 +; +target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64" +target triple = "dxil-ms-dx" + +%dx.types.Handle = type { i8* } +%dx.types.CBufRet.i32 = type { i32, i32, i32, i32 } +%dx.types.ResRet.i32 = type { i32, i32, i32, i32, i32 } +%"class.RWStructuredBuffer" = type { i32 } +%"class.RWStructuredBuffer" = type { double } +%Constants = type { i32, i32, i32, i32, <4 x i32>, <4 x i32>, <4 x i32>, <4 x i32>, <4 x i32>, <4 x i32>, i32 } + +define void @GridSample() { + %1 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 2, i32 2, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) + %2 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 1, i32 1, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) + %3 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 0, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) + %4 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 2, i32 0, i32 0, i1 false) ; CreateHandle(resourceClass,rangeId,index,nonUniformIndex) + %5 = call i32 @dx.op.threadId.i32(i32 93, i32 0) ; ThreadId(component) + %6 = call %dx.types.CBufRet.i32 @dx.op.cbufferLoadLegacy.i32(i32 59, %dx.types.Handle %4, i32 0) ; CBufferLoadLegacy(handle,regIndex) + %7 = extractvalue %dx.types.CBufRet.i32 %6, 0 + %8 = add i32 %7, %5 + %9 = extractvalue %dx.types.CBufRet.i32 %6, 1 + %10 = icmp ult i32 %8, %9 + br i1 %10, label %11, label %3389 + +;