diff --git a/.gdn/.gdntsa b/.gdn/.gdntsa
index 2992cab431..e49848d116 100644
--- a/.gdn/.gdntsa
+++ b/.gdn/.gdntsa
@@ -1,3 +1,3 @@
{
- "codebaseName": "onnxruntime_master"
+ "codebaseName": "onnxruntime_main"
}
\ No newline at end of file
diff --git a/.github/workflows/generate-skip-doc-change.py b/.github.upstream/generate-skip-doc-change.py
similarity index 100%
rename from .github/workflows/generate-skip-doc-change.py
rename to .github.upstream/generate-skip-doc-change.py
diff --git a/.github/workflows/skip-doc-change.yml.j2 b/.github.upstream/skip-doc-change.yml.j2
similarity index 100%
rename from .github/workflows/skip-doc-change.yml.j2
rename to .github.upstream/skip-doc-change.yml.j2
diff --git a/.github/workflows/cffconvert.yml b/.github.upstream/workflows/cffconvert.yml
similarity index 100%
rename from .github/workflows/cffconvert.yml
rename to .github.upstream/workflows/cffconvert.yml
diff --git a/.github/workflows/codeql.yml b/.github.upstream/workflows/codeql.yml
similarity index 100%
rename from .github/workflows/codeql.yml
rename to .github.upstream/workflows/codeql.yml
diff --git a/.github/workflows/generated_fake_win_gpu_ci.yml b/.github.upstream/workflows/generated_fake_win_gpu_ci.yml
similarity index 100%
rename from .github/workflows/generated_fake_win_gpu_ci.yml
rename to .github.upstream/workflows/generated_fake_win_gpu_ci.yml
diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github.upstream/workflows/gradle-wrapper-validation.yml
similarity index 100%
rename from .github/workflows/gradle-wrapper-validation.yml
rename to .github.upstream/workflows/gradle-wrapper-validation.yml
diff --git a/.github/workflows/labeler.yml b/.github.upstream/workflows/labeler.yml
similarity index 100%
rename from .github/workflows/labeler.yml
rename to .github.upstream/workflows/labeler.yml
diff --git a/.github/workflows/lint.yml b/.github.upstream/workflows/lint.yml
similarity index 99%
rename from .github/workflows/lint.yml
rename to .github.upstream/workflows/lint.yml
index 83d2a4bfd6..91f9a8ee3d 100644
--- a/.github/workflows/lint.yml
+++ b/.github.upstream/workflows/lint.yml
@@ -3,8 +3,8 @@ name: Lint
on:
push:
branches:
- - master
- main
+ - rel-*
pull_request:
jobs:
diff --git a/.github/workflows/linux.yml b/.github.upstream/workflows/linux.yml
similarity index 88%
rename from .github/workflows/linux.yml
rename to .github.upstream/workflows/linux.yml
index 92ddf900d9..4c84406fb7 100644
--- a/.github/workflows/linux.yml
+++ b/.github.upstream/workflows/linux.yml
@@ -2,10 +2,14 @@ name: Linux_CI
on:
push:
branches:
- - master
- main
+ - rel-*
pull_request:
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
jobs:
Onnxruntime-TVM:
runs-on: ubuntu-latest
diff --git a/.github/workflows/publish-c-apidocs.yml b/.github.upstream/workflows/publish-c-apidocs.yml
similarity index 100%
rename from .github/workflows/publish-c-apidocs.yml
rename to .github.upstream/workflows/publish-c-apidocs.yml
diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github.upstream/workflows/publish-csharp-apidocs.yml
similarity index 98%
rename from .github/workflows/publish-csharp-apidocs.yml
rename to .github.upstream/workflows/publish-csharp-apidocs.yml
index f038b66dec..eb327fc645 100644
--- a/.github/workflows/publish-csharp-apidocs.yml
+++ b/.github.upstream/workflows/publish-csharp-apidocs.yml
@@ -28,7 +28,7 @@ jobs:
- name: Setup .NET
uses: actions/setup-dotnet@v2
with:
- dotnet-version: 5.0.x
+ dotnet-version: 6.0.x
- name: Restore dependencies
run: dotnet restore csharp/ApiDocs/ApiDocs.csproj
- name: Download DocFX
diff --git a/.github/workflows/publish-java-apidocs.yml b/.github.upstream/workflows/publish-java-apidocs.yml
similarity index 100%
rename from .github/workflows/publish-java-apidocs.yml
rename to .github.upstream/workflows/publish-java-apidocs.yml
diff --git a/.github/workflows/publish-python-apidocs.yml b/.github.upstream/workflows/publish-python-apidocs.yml
similarity index 80%
rename from .github/workflows/publish-python-apidocs.yml
rename to .github.upstream/workflows/publish-python-apidocs.yml
index c0a49e348c..5f3e172fc7 100644
--- a/.github/workflows/publish-python-apidocs.yml
+++ b/.github.upstream/workflows/publish-python-apidocs.yml
@@ -35,19 +35,19 @@ jobs:
python3 -m pip install --upgrade pip
cd docs/python
python3 -m pip install -r requirements.txt
- python3 -m pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ ort-nightly
+ python3 -m pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_cpu.html
python3 -m pip list
- name: Generate Python docs with Sphinx
run: |
cd tools/doc
./builddoc.sh /usr/bin ../.. ../../build
- name: Log source commit
- run: git rev-parse --short HEAD > build/docs/inference/html/source-version.txt
+ run: git rev-parse --short HEAD > build/docs/html/source-version.txt
- name: Move Python docs into site
run: |
rm -rf _site/docs/api/python
- mkdir -p _site/docs/api
- mv build/docs/inference/html _site/docs/api/python
+ mkdir -p _site/docs/api/
+ mv build/docs/html _site/docs/api/python
- name: Upload docs artifact
uses: actions/upload-artifact@v3
with:
diff --git a/.github/workflows/windows.yml b/.github.upstream/workflows/windows.yml
similarity index 89%
rename from .github/workflows/windows.yml
rename to .github.upstream/workflows/windows.yml
index f2c61d3359..ffa91aa221 100644
--- a/.github/workflows/windows.yml
+++ b/.github.upstream/workflows/windows.yml
@@ -2,10 +2,14 @@ name: Windows_CI
on:
push:
branches:
- - master
- main
+ - rel-*
pull_request:
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
jobs:
Onnxruntime-TVM:
runs-on: windows-2019
diff --git a/.github/ISSUE_TEMPLATE/04-web.yml b/.github/ISSUE_TEMPLATE/04-web.yml
index 15919a5983..d84226ff5b 100644
--- a/.github/ISSUE_TEMPLATE/04-web.yml
+++ b/.github/ISSUE_TEMPLATE/04-web.yml
@@ -55,8 +55,10 @@ body:
attributes:
label: Execution Provider
options:
- - WebGL
- - WASM
+ - "'webgl' (WebGL)"
+ - "'wasm'/'cpu' (WebAssembly CPU)"
+ - "'xnnpack' (WebAssembly XNNPACK)"
+ - "'webgpu' (WebGPU)"
- Other / Unknown
multiple: yes
validations:
diff --git a/.github/workflows/publish-gh-pages.yml b/.github/workflows/publish-gh-pages.yml
index 5ddb1e3bb0..0f1e4cf840 100644
--- a/.github/workflows/publish-gh-pages.yml
+++ b/.github/workflows/publish-gh-pages.yml
@@ -98,4 +98,4 @@ jobs:
steps:
- name: Deploy to GitHub Pages
id: deployment
- uses: actions/deploy-pages@v1
+ uses: actions/deploy-pages@v2
diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml
index 518ea8e69a..b966793cc0 100644
--- a/.github/workflows/publish-objectivec-apidocs.yml
+++ b/.github/workflows/publish-objectivec-apidocs.yml
@@ -21,7 +21,7 @@ permissions:
jobs:
build:
name: Generate Objective-C API docs
- runs-on: macos-12
+ runs-on: macos-13
steps:
- uses: actions/checkout@v3
diff --git a/.github/workflows/sca.yml b/.github/workflows/sca.yml
new file mode 100644
index 0000000000..4b19e71744
--- /dev/null
+++ b/.github/workflows/sca.yml
@@ -0,0 +1,133 @@
+name: Windows_SCA
+on:
+ push:
+ branches:
+ - main
+ - rel-*
+ pull_request:
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+env:
+ AZCOPY_AUTO_LOGIN_TYPE: MSI
+ AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4
+
+jobs:
+ Onnxruntime-SCA-training-CUDA:
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"]
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ submodules: false
+ - uses: actions/setup-python@v3
+ with:
+ python-version: '3.11.x'
+ architecture: 'x64'
+
+ - uses: actions/setup-node@v3
+ with:
+ node-version: 18
+
+ - name: Download cuda
+ run: azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v11.8" cuda_sdk
+
+
+ - name: Delete build folder
+ run: |
+ if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b }
+ &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x64 -install_prefix D:\b\Debug\installed -build_config Debug
+
+ # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter.
+ - name: Build code
+ env:
+ CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake'
+ run: python tools\ci_build\build.py --enable_training --build_java --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_pybind --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --use_cuda --cuda_home=${{ github.workspace }}\cuda_sdk\v11.8 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75
+
+ - name: Generate sarif
+ working-directory: D:\b
+ run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output
+
+ - name: Upload SARIF to GitHub
+ uses: github/codeql-action/upload-sarif@v2
+ continue-on-error: true
+ with:
+ sarif_file: ${{ github.workspace }}\output\MergeResult.sarif
+ category: VS_SCA
+
+ # No python
+ Onnxruntime-SCA-win32-WINML-x64:
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"]
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ submodules: false
+ - uses: actions/setup-python@v3
+ with:
+ python-version: '3.11.x'
+ architecture: 'x64'
+
+ - uses: actions/setup-node@v3
+ with:
+ node-version: 18
+
+ - name: Delete build folder
+ run: |
+ if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b }
+ &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x64 -install_prefix D:\b\Debug\installed -build_config Debug
+
+ # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter.
+ - name: Build code
+ env:
+ CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake'
+ run: python tools\ci_build\build.py --enable_training --build_java --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --ms_experimental --use_dml --use_winml --disable_rtti --enable_wcos --build_shared_lib
+
+ - name: Generate sarif
+ working-directory: D:\b
+ run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output
+
+ - name: Upload SARIF to GitHub
+ uses: github/codeql-action/upload-sarif@v2
+ continue-on-error: true
+ with:
+ sarif_file: ${{ github.workspace }}\output\MergeResult.sarif
+ category: VS_SCA_WIN32_WINML_X64
+
+ # No java, No python
+ Onnxruntime-SCA-win32-WINML-x86:
+ runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"]
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ submodules: false
+ - uses: actions/setup-python@v3
+ with:
+ python-version: '3.11.x'
+ architecture: 'x86'
+
+ - uses: actions/setup-node@v3
+ with:
+ node-version: 18
+
+ - name: Delete build folder
+ run: |
+ if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b }
+ &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x86 -install_prefix D:\b\Debug\installed -build_config Debug
+
+ # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter.
+ - name: Build code
+ env:
+ CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake'
+ run: python tools\ci_build\build.py --enable_training --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --ms_experimental --use_dml --use_winml --disable_rtti --enable_wcos --build_shared_lib
+
+ - name: Generate sarif
+ working-directory: D:\b
+ run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output
+
+ - name: Upload SARIF to GitHub
+ uses: github/codeql-action/upload-sarif@v2
+ continue-on-error: true
+ with:
+ sarif_file: ${{ github.workspace }}\output\MergeResult.sarif
+ category: VS_SCA_WIN32_WINML_X86
diff --git a/.github/workflows/wheel.yaml b/.github/workflows/wheel.yaml
new file mode 100644
index 0000000000..b4f7b984ef
--- /dev/null
+++ b/.github/workflows/wheel.yaml
@@ -0,0 +1,47 @@
+name: CI && Release & Upload Wheel
+
+on:
+ workflow_dispatch:
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ build_and_upload_wheel:
+ runs-on: linux-mk
+ container:
+ image: ghcr.io/quadric-io/tvm:devel
+ options: "--mount type=bind,source=${{ github.workspace }},target=/workspace"
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v3
+ - name: Build ONNX Runtime wheel
+ working-directory: /workspace
+ run: |
+ python3 -m pip install cmake --upgrade
+ ./build.sh --build_wheel --config Release --parallel ${{ github.event_name == 'pull_request' && ' ' || '--skip_tests'}} --skip_submodule_sync
+ wheel_path=$(find . -name '*.whl' | xargs readlink -f)
+ echo "wheel_path=$wheel_path" >> $GITHUB_ENV
+ - name: Count releases
+ id: count_releases
+ if: github.event_name != 'pull_request'
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ run: |
+ apt-get update && apt-get install curl jq -y
+ count=$(curl --request GET \
+ --url https://api.github.com/repos/${{ github.repository }}/releases \
+ --header "Authorization: Bearer $GITHUB_TOKEN" | jq length)
+ echo "count=$count" >> $GITHUB_ENV
+ - name: Create Release
+ if: github.event_name != 'pull_request'
+ uses: softprops/action-gh-release@v1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ tag_name: v${{ env.count }}
+ name: Release v${{ env.count }}
+ files: ${{ env.wheel_path }}
diff --git a/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml b/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml
index fa2b475fe9..b9de1b79e1 100644
--- a/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml
+++ b/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml
@@ -351,6 +351,31 @@ extends:
- script: |
dir $(Build.SourcesDirectory)\unzipped\runtimes\win-x64\_native
+ - task: EsrpCodeSigning@2
+ displayName: "Sign Nuget package"
+ inputs:
+ ConnectedServiceName: 'OnnxRuntime CodeSign 20190817'
+ FolderPath: $(Build.ArtifactStagingDirectory)
+ Pattern: '*.nupkg'
+ signConfigType: inlineSignParams
+ inlineOperation: |
+ [
+ {
+ "keyCode": "CP-401405",
+ "operationSetCode": "NuGetSign",
+ "parameters": [ ],
+ "toolName": "sign",
+ "toolVersion": "1.0"
+ },
+ {
+ "keyCode": "CP-401405",
+ "operationSetCode": "NuGetVerify",
+ "parameters": [ ],
+ "toolName": "sign",
+ "toolVersion": "1.0"
+ }
+ ]
+
- job: NuGet_Publishing
pool:
type: windows
diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config
index 2ec9b577c1..8eef0b5bac 100644
--- a/.pipelines/nuget_config/x64/packages.config
+++ b/.pipelines/nuget_config/x64/packages.config
@@ -1,6 +1,6 @@

-
+
diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config
index 6f9e3ef496..81f97948f1 100644
--- a/.pipelines/nuget_config/x86/packages.config
+++ b/.pipelines/nuget_config/x86/packages.config
@@ -1,6 +1,6 @@

-
+
diff --git a/.pipelines/windowsai-steps.yml b/.pipelines/windowsai-steps.yml
index 0b736da427..a29e2fe6c8 100644
--- a/.pipelines/windowsai-steps.yml
+++ b/.pipelines/windowsai-steps.yml
@@ -80,11 +80,11 @@ jobs:
# must call vsdevcmd first to add cmake to PATH
- script: |
- curl -O -L https://github.com/Kitware/CMake/releases/download/v3.24.3/cmake-3.24.3-windows-x86_64.zip
- 7z x cmake-3.24.3-windows-x86_64.zip
+ curl -O -L https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-windows-x86_64.zip
+ 7z x cmake-3.26.3-windows-x86_64.zip
set PYTHONHOME=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools
set PYTHONPATH=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools
- $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 16 2019" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.24.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.24.3-windows-x86_64\bin\ctest.exe
+ $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 16 2019" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe
workingDirectory: '$(Build.BinariesDirectory)'
displayName: 'Generate cmake config'
diff --git a/Package.swift b/Package.swift
index 7f8bfe0c3c..54609b104d 100644
--- a/Package.swift
+++ b/Package.swift
@@ -79,24 +79,10 @@ if let pod_archive_path = ProcessInfo.processInfo.environment["ORT_IOS_POD_LOCAL
package.targets.append(Target.binaryTarget(name: "onnxruntime", path: pod_archive_path))
} else {
- // When creating the release version:
- // - remove the fatalError
- // - uncomment the package.targets.append call
- // - update the major/minor/patch version info in the url
- // - insert the checksum info from the onnxruntime-ios-packaging-pipeline CI's 'Print ORT iOS Pod checksum'
- // stage output (or download the pod archive artifact from the CI and run `shasum -a 256 `
- // to manually calculate it).
- // The checksum length and chars should look something like
- // "c89cd106ff02eb3892243acd7c4f2bd8e68c2c94f2751b5e35f98722e10c042b"
- //
- // package.targets.append(
- // Target.binaryTarget(name: "onnxruntime",
- // url: "https://onnxruntimepackages.z14.web.core.windows.net/pod-archive-onnxruntime-c-.zip",
- // checksum: "Insert checksum here")
- // )
-
- fatalError("It is not valid to use a non-release branch from https://github.com/microsoft/onnxruntime.\n" +
- "Please use a release branch (e.g. rel-1.15.0), or build the ONNX Runtime iOS pod archive locally " +
- "and set the ORT_IOS_POD_LOCAL_PATH environment variable.\n" +
- "See Package.swift for more information on using a local pod archive.")
+ // ORT 1.15.0 release
+ package.targets.append(
+ Target.binaryTarget(name: "onnxruntime",
+ url: "https://onnxruntimepackages.z14.web.core.windows.net/pod-archive-onnxruntime-c-1.15.0.zip",
+ checksum: "9b41412329a73d7d298b1d94ab40ae9adb65cb84f132054073bc82515b4f5f82")
+ )
}
diff --git a/README_EPU.md b/README_EPU.md
new file mode 100644
index 0000000000..e42cbe176d
--- /dev/null
+++ b/README_EPU.md
@@ -0,0 +1,27 @@
+# The Quadric Version of onnxruntime
+
+This repository contains a distribution of onnxruntime with additional operator quantization capabilities.
+
+
+## Prerequisites:
+- python 3.9
+- pip
+
+## Clone repository and build:
+```
+git clone --recursive https://github.com/quadric-io/onnxruntime onnxruntime
+cd onnxruntime
+# Install wheel
+pip3 install wheel
+# Build the python package
+./build.sh --build_wheel --config Release --parallel --compile_no_warning_as_error --skip_tests --skip_submodule_sync
+```
+
+## Install
+```
+# Find the wheel you just created
+$ find . -name '*.whl'
+./build/MacOS/Release/dist/onnxruntime-1.15.1-cp310-cp310-macosx_13_0_arm64.whl
+# Install it
+pip3 install ./build/MacOS/Release/dist/onnxruntime-1.15.1-cp310-cp310-macosx_13_0_arm64.whl
+```
diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt
index b4d981d42d..e099840e64 100644
--- a/ThirdPartyNotices.txt
+++ b/ThirdPartyNotices.txt
@@ -5993,3 +5993,32 @@ https://github.com/tensorflow/tfjs
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+
+——
+
+curl/curl
+
+https://github.com/curl
+
+COPYRIGHT AND PERMISSION NOTICE
+
+Copyright (C) Daniel Stenberg, <daniel@haxx.se>, and many
+contributors, see the THANKS file.
+
+All rights reserved.
+
+Permission to use, copy, modify, and distribute this software for any purpose
+with or without fee is hereby granted, provided that the above copyright
+notice and this permission notice appear in all copies.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. IN
+NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
+OR OTHER DEALINGS IN THE SOFTWARE.
+
+Except as contained in this notice, the name of a copyright holder shall not
+be used in advertising or otherwise to promote the sale, use or other dealings
+in this Software without prior written authorization of the copyright holder.
\ No newline at end of file
diff --git a/VERSION_NUMBER b/VERSION_NUMBER
index 141f2e805b..ace44233b4 100644
--- a/VERSION_NUMBER
+++ b/VERSION_NUMBER
@@ -1 +1 @@
-1.15.0
+1.15.1
diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json
index 989756361b..4f4c1e7e74 100644
--- a/cgmanifests/generated/cgmanifest.json
+++ b/cgmanifests/generated/cgmanifest.json
@@ -112,7 +112,7 @@
"component": {
"type": "git",
"git": {
- "commitHash": "9b7bca2a723ff94edcd007d93b5d0cf1838591dc",
+ "commitHash": "a0d77f18516d2da7468a96b0de3b737266f23176",
"repositoryUrl": "https://github.com/onnx/onnx.git"
},
"comments": "git submodule at cmake/external/onnx"
@@ -288,16 +288,6 @@
"comments": "mp11"
}
},
- {
- "component": {
- "type": "git",
- "git": {
- "commitHash": "3b58938e025c41d2fcd89fa22028eefaa81a18ad",
- "repositoryUrl": "https://github.com/onnx/onnx.git"
- },
- "comments": "onnx"
- }
- },
{
"component": {
"type": "git",
@@ -447,6 +437,26 @@
},
"comments": "triton"
}
+ },
+ {
+ "component": {
+ "type": "git",
+ "git": {
+ "commitHash": "94142d8391c9791ec71c38336436319a2d4ac7a0",
+ "repositoryUrl": "https://github.com/microsoft/onnxruntime-extensions.git"
+ },
+ "comments": "extensions"
+ }
+ },
+ {
+ "component": {
+ "type": "git",
+ "git": {
+ "commitHash": "b16d1fa8ee567b52c09a0f89940b07d8491b881d",
+ "repositoryUrl": "https://github.com/curl/curl.git"
+ },
+ "comments": "curl"
+ }
}
]
}
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 0b2f980ab2..2e54063598 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -226,6 +226,7 @@ option(onnxruntime_BUILD_CACHE "onnxruntime build with cache" OFF)
cmake_dependent_option(MSVC_Z7_OVERRIDE "replacing /Zi and /ZI with /Z7 when using MSVC with CCache" ON "onnxruntime_BUILD_CACHE; MSVC" OFF)
option(onnxruntime_USE_AZURE "Build with azure inferencing support" OFF)
+option(onnxruntime_USE_LOCK_FREE_QUEUE "Build with lock-free task queue for threadpool." OFF)
# ENABLE_TRAINING includes all training functionality
# The following 2 entry points
@@ -272,6 +273,9 @@ if (onnxruntime_USE_ROCM)
file(GLOB rocm_cmake_components ${onnxruntime_ROCM_HOME}/lib/cmake/*)
list(APPEND CMAKE_PREFIX_PATH ${rocm_cmake_components})
+ # Force cmake to accept the configured HIP compiler. Because the configured CMAKE_PREFIX_PATH does not work during
+ # enable_language(HIP), we might need to move configuring of CMAKE_PREFIX_PATH to build.py (in the future).
+ set(CMAKE_HIP_COMPILER_FORCED ON)
enable_language(HIP)
# NOTE: Flags -mllvm -amdgpu-early-inline-all=true are critical for gpu kernel code performance. -mllvm passes the
@@ -740,7 +744,9 @@ if (onnxruntime_USE_AZURE)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_AZURE=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES azure)
endif()
-
+if (onnxruntime_USE_LOCK_FREE_QUEUE)
+ add_compile_definitions(USE_LOCK_FREE_QUEUE)
+endif()
if (onnxruntime_ENABLE_LAZY_TENSOR)
# To support LazyTensor, ORT needs to call Python function from C/C++.
diff --git a/cmake/deps.txt b/cmake/deps.txt
index 6b7fb0c95f..b887418256 100644
--- a/cmake/deps.txt
+++ b/cmake/deps.txt
@@ -23,7 +23,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36
microsoft_wil;https://github.com/microsoft/wil/archive/5f4caba4e7a9017816e47becdd918fcc872039ba.zip;fd119887d0d17c37adf1fc227b054befa28158ad
mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41
mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.79.0.zip;c8f04e378535ededbe5af52c8f969d2dedbe73d5
-onnx;https://github.com/onnx/onnx/archive/3b58938e025c41d2fcd89fa22028eefaa81a18ad.zip;e0e5dda9eea5cd5ecae3bd8be86e477016b6be02
+onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.14.0.zip;3c6e43a36f94addc15afe939860127a1d74a9488
#use the last commit of 8.6-EA branch (https://github.com/onnx/onnx-tensorrt/commit/ba6a4fb34fdeaa3613bf981610c657e7b663a699)
onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/ba6a4fb34fdeaa3613bf981610c657e7b663a699.zip;5a474ed86e2c4ee4085d3daeff8222866e933dc0
protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa
@@ -48,4 +48,5 @@ b64;https://github.com/libb64/libb64/archive/refs/tags/v2.0.0.1.zip;815b6d31d50d
pthread;https://sourceforge.net/projects/pthreads4w/files/pthreads4w-code-v3.0.0.zip;3b9e417e4474c34542b76ad40529e396ac109fb4
triton;https://github.com/triton-inference-server/server/archive/refs/tags/v2.28.0.zip;4b305570aa1e889946e20e36050b6770e4108fee
# above are deps introduced by triton client, might remove after 1.14 release
-extensions;https://github.com/microsoft/onnxruntime-extensions/archive/81e7799c69044c745239202085eb0a98f102937b.zip;d53487035174a046628359289ad27aa0ac0380c9
+extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
+curl;https://github.com/curl/curl/archive/refs/tags/curl-8_0_1.zip;b16d1fa8ee567b52c09a0f89940b07d8491b881d
\ No newline at end of file
diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake
index 161db093a1..1ee3a06b41 100644
--- a/cmake/external/dml.cmake
+++ b/cmake/external/dml.cmake
@@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
- set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.10.1)
+ set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.12.0)
# Restore nuget packages, which will pull down the DirectML redist package.
add_custom_command(
diff --git a/cmake/external/emsdk b/cmake/external/emsdk
index 0ab19024f0..b113f24842 160000
--- a/cmake/external/emsdk
+++ b/cmake/external/emsdk
@@ -1 +1 @@
-Subproject commit 0ab19024f08c6673a713e454ef8bd95e174c807f
+Subproject commit b113f24842c6e97fe3e352084db09a6e278593ae
diff --git a/cmake/external/onnx b/cmake/external/onnx
index 9b7bca2a72..a0d77f1851 160000
--- a/cmake/external/onnx
+++ b/cmake/external/onnx
@@ -1 +1 @@
-Subproject commit 9b7bca2a723ff94edcd007d93b5d0cf1838591dc
+Subproject commit a0d77f18516d2da7468a96b0de3b737266f23176
diff --git a/cmake/external/triton.cmake b/cmake/external/triton.cmake
index b24768bd89..41da407adf 100644
--- a/cmake/external/triton.cmake
+++ b/cmake/external/triton.cmake
@@ -44,13 +44,11 @@ if (WIN32)
vcpkg_install(re2)
vcpkg_install(boost-interprocess)
vcpkg_install(boost-stacktrace)
- vcpkg_install(zlib)
vcpkg_install(pthread)
vcpkg_install(b64)
add_dependencies(getb64 getpthread)
- add_dependencies(getpthread getzlib)
- add_dependencies(getzlib getboost-stacktrace)
+ add_dependencies(getpthread getboost-stacktrace)
add_dependencies(getboost-stacktrace getboost-interprocess)
add_dependencies(getboost-interprocess getre2)
add_dependencies(getre2 getrapidjson)
@@ -59,11 +57,11 @@ if (WIN32)
ExternalProject_Add(triton
GIT_REPOSITORY https://github.com/triton-inference-server/client.git
- GIT_TAG r22.12
+ GIT_TAG r23.05
PREFIX triton
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
- CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${onnxruntime_target_platform}-windows -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON
+ CMAKE_ARGS -DVCPKG_TARGET_TRIPLET=${onnxruntime_target_platform}-windows -DCMAKE_TOOLCHAIN_FILE=${VCPKG_SRC}/scripts/buildsystems/vcpkg.cmake -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
INSTALL_COMMAND ""
UPDATE_COMMAND "")
@@ -77,7 +75,9 @@ else()
PREFIX rapidjson
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/rapidjson-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/rapidjson-build
- CMAKE_ARGS -DRAPIDJSON_BUILD_TESTS=OFF -DRAPIDJSON_BUILD_DOC=OFF -DRAPIDJSON_BUILD_EXAMPLES=OFF)
+ CMAKE_ARGS -DRAPIDJSON_BUILD_TESTS=OFF -DRAPIDJSON_BUILD_DOC=OFF -DRAPIDJSON_BUILD_EXAMPLES=OFF
+ INSTALL_COMMAND ""
+ UPDATE_COMMAND "")
ExternalProject_Get_Property(rapidjson source_dir)
set(RAPIDJSON_INCLUDE_DIR ${source_dir}/include)
@@ -85,16 +85,14 @@ else()
ExternalProject_Add(triton
GIT_REPOSITORY https://github.com/triton-inference-server/client.git
- GIT_TAG r22.12
+ GIT_TAG r23.05
PREFIX triton
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-src
BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps/triton-build
- CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON
+ CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=binary -DTRITON_ENABLE_CC_HTTP=ON -DTRITON_ENABLE_ZLIB=OFF
INSTALL_COMMAND ""
UPDATE_COMMAND "")
- add_dependencies(triton rapidjson)
-
endif() #if (WIN32)
ExternalProject_Get_Property(triton SOURCE_DIR)
diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake
index 02861458c9..92f59b8326 100644
--- a/cmake/onnxruntime.cmake
+++ b/cmake/onnxruntime.cmake
@@ -28,9 +28,9 @@ macro(get_mobile_api_headers _HEADERS)
)
if (onnxruntime_ENABLE_TRAINING_APIS)
- list(APPEND ${_HEADERS} "${REPO_ROOT/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h}")
- list(APPEND ${_HEADERS} "${REPO_ROOT/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h}")
- list(APPEND ${_HEADERS} "${REPO_ROOT/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline_api.h}")
+ list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h")
+ list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h")
+ list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h")
endif()
# need to add header files for enabled EPs
diff --git a/cmake/onnxruntime_config.h.in b/cmake/onnxruntime_config.h.in
index 9fc46018da..4f5125569c 100644
--- a/cmake/onnxruntime_config.h.in
+++ b/cmake/onnxruntime_config.h.in
@@ -21,5 +21,5 @@
#cmakedefine HAS_FORMAT_TRUNCATION
#cmakedefine HAS_BITWISE_INSTEAD_OF_LOGICAL
#cmakedefine HAS_REALLOCARRAY
-#cmakedefine ORT_VERSION "@ORT_VERSION@"
-#cmakedefine ORT_BUILD_INFO "@ORT_BUILD_INFO@"
+#cmakedefine ORT_VERSION u8"@ORT_VERSION@"
+#cmakedefine ORT_BUILD_INFO u8"@ORT_BUILD_INFO@"
diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake
index 5c947a52b7..1254cd16f2 100644
--- a/cmake/onnxruntime_framework.cmake
+++ b/cmake/onnxruntime_framework.cmake
@@ -40,13 +40,13 @@ onnxruntime_add_static_library(onnxruntime_framework ${onnxruntime_framework_src
if (onnxruntime_USE_AZURE)
add_dependencies(onnxruntime_framework triton)
- target_include_directories(onnxruntime_framework PRIVATE ${TRITON_BIN}/include)
+ target_include_directories(onnxruntime_framework PRIVATE ${TRITON_BIN}/include ${TRITON_THIRD_PARTY}/curl/include)
link_directories(${TRITON_BIN}/lib ${TRITON_BIN}/lib64 ${TRITON_THIRD_PARTY}/curl/lib ${TRITON_THIRD_PARTY}/curl/lib64)
if (WIN32)
link_directories(${VCPKG_SRC}/installed/${onnxruntime_target_platform}-windows/lib)
- target_link_libraries(onnxruntime_framework PRIVATE libcurl httpclient_static ws2_32 crypt32 Wldap32 zlib)
+ target_link_libraries(onnxruntime_framework PRIVATE libcurl httpclient_static ws2_32 crypt32 Wldap32)
else()
diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index 2839713406..0daa1b8a3d 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -594,7 +594,7 @@ if (onnxruntime_USE_DNNL)
add_dependencies(onnxruntime_providers_dnnl onnxruntime_providers_shared project_dnnl ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_include_directories(onnxruntime_providers_dnnl PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} ${DNNL_INCLUDE_DIR} ${DNNL_OCL_INCLUDE_DIR})
# ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
- target_link_libraries(onnxruntime_providers_dnnl PRIVATE dnnl ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${ABSEIL_LIBS} ${GSL_TARGET})
+ target_link_libraries(onnxruntime_providers_dnnl PRIVATE dnnl ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${ABSEIL_LIBS} ${GSL_TARGET} safeint_interface)
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/dnnl DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers)
set_target_properties(onnxruntime_providers_dnnl PROPERTIES FOLDER "ONNXRuntime")
set_target_properties(onnxruntime_providers_dnnl PROPERTIES LINKER_LANGUAGE CXX)
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 56a7a4b350..9347be180d 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -848,7 +848,7 @@ if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
endif()
if (onnxruntime_BUILD_WEBASSEMBLY)
set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js)
- set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=1048576 -s ALLOW_MEMORY_GROWTH=1 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1")
+ set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1")
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s USE_PTHREADS=1 -s PROXY_TO_PTHREAD=1")
endif()
diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj
index 5473246e8b..0288d752d8 100644
--- a/csharp/OnnxRuntime.CSharp.proj
+++ b/csharp/OnnxRuntime.CSharp.proj
@@ -16,6 +16,7 @@ CMake creates a target to this project
nuget
x64
false
+
false
None
@@ -77,6 +78,7 @@ CMake creates a target to this project
$([System.DateTime]::UtcNow.ToString(yyyyMMdd))
$([System.DateTime]::UtcNow.ToString(hhmm))
@(MajorVersionNumber)
+ $(PackageVersion)$(ReleaseVersionSuffix)
$(PackageVersion)
$(PackageVersion)-dev-$(CurrentDate)-$(CurrentTime)-$(GitCommitHashShort)
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj
index 78083a8cc1..ad468b0c6d 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj
@@ -41,6 +41,7 @@
net6.0;net6.0-android;net6.0-ios;net6.0-macos
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index 14e93dc05f..5fcddf7d0c 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -286,6 +286,7 @@ public struct OrtApi
public IntPtr CastTypeInfoToOptionalTypeInfo;
public IntPtr GetOptionalContainedTypeInfo;
public IntPtr GetResizedStringTensorElementBuffer;
+ public IntPtr KernelContext_GetAllocator;
}
internal static class NativeMethods
@@ -1654,7 +1655,6 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
public static DOrtModelMetadataLookupCustomMetadataMap OrtModelMetadataLookupCustomMetadataMap;
-
///
/// Frees ModelMetadata instance
///
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
index 2df16ccae0..1a03338298 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
@@ -276,7 +276,7 @@ public static OrtEnv Instance()
///
///
/// if the singleton has already been created
- public static OrtEnv CreateInstanceWithOptions(EnvironmentCreationOptions options)
+ public static OrtEnv CreateInstanceWithOptions(ref EnvironmentCreationOptions options)
{
// Non-thread safe, best effort hopefully helpful check.
// Environment is usually created once per process, so this should be fine.
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
index 6bbb159a3d..30951bae3f 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -375,10 +375,10 @@ public IntPtr Appender(IntPtr handle, IntPtr[] optKeys, IntPtr[] optValues, UInt
/// Optional key/value pairs to specify execution provider options.
public void AppendExecutionProvider(string providerName, Dictionary providerOptions = null)
{
- if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN")
+ if (providerName != "SNPE" && providerName != "XNNPACK" && providerName != "QNN" && providerName != "AZURE")
{
throw new NotSupportedException(
- "Only QNN, SNPE and XNNPACK execution providers can be enabled by this method.");
+ "Only QNN, SNPE, XNNPACK and AZURE execution providers can be enabled by this method.");
}
if (providerOptions == null)
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs
index 5fb983ab37..83d114b535 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs
@@ -401,21 +401,22 @@ public void FromBuffer(FixedBufferOnnxValue buffer)
throw new ArgumentException(errorMessage);
}
- IntPtr numElementsTrainingOnly = IntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out numElementsTrainingOnly));
-
- UIntPtr bufferSize = UIntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, true));
- if ((long)bufferSize.ToUInt64() == numElementsTrainingOnly.ToInt64())
+ // Here buffer size represents the number of elements in the buffer
+ IntPtr bufferSize = IntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out bufferSize));
+
+ // OrtGetParametersSize returns the total number of elements in the model's parameters.
+ UIntPtr numElementsTrainingOnly = UIntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true));
+ if (bufferSize.ToInt64() == (long)numElementsTrainingOnly.ToUInt64())
{
NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true));
return;
}
- IntPtr numElements = IntPtr.Zero;
- bufferSize = UIntPtr.Zero;
- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, false));
- if ((long)bufferSize.ToUInt64() != numElements.ToInt64())
+ UIntPtr numElements = UIntPtr.Zero;
+ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false));
+ if (bufferSize.ToInt64() != (long)numElements.ToUInt64())
{
string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString();
throw new ArgumentException(errorMessage);
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs
index 63ddf51116..8a94b199ff 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs
@@ -80,7 +80,7 @@ public void TestUpdatingEnvWithCustomLogLevel()
logLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL
};
- ortEnvInstance = OrtEnv.CreateInstanceWithOptions(envOptions);
+ ortEnvInstance = OrtEnv.CreateInstanceWithOptions(ref envOptions);
Assert.True(OrtEnv.IsCreated);
Assert.Equal(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, ortEnvInstance.EnvLogLevel);
@@ -92,7 +92,7 @@ public void TestUpdatingEnvWithCustomLogLevel()
logId = "CSharpOnnxRuntimeTestLogid"
};
- ortEnvInstance = OrtEnv.CreateInstanceWithOptions(envOptions);
+ ortEnvInstance = OrtEnv.CreateInstanceWithOptions(ref envOptions);
Assert.Equal(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, ortEnvInstance.EnvLogLevel);
// Change and see if this takes effect
@@ -118,7 +118,7 @@ public void TestUpdatingEnvWithThreadingOptions()
};
// Make sure we start anew
- var env = OrtEnv.CreateInstanceWithOptions(envOptions);
+ var env = OrtEnv.CreateInstanceWithOptions(ref envOptions);
Assert.True(OrtEnv.IsCreated);
}
}
@@ -165,7 +165,7 @@ public void TesEnvWithCustomLogger()
LoggingInvokes = 0;
- var env = OrtEnv.CreateInstanceWithOptions(envOptions);
+ var env = OrtEnv.CreateInstanceWithOptions(ref envOptions);
Assert.True(OrtEnv.IsCreated);
var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
@@ -199,7 +199,7 @@ public void TestEnvWithCustomLoggerAndThredingOptions()
LoggingInvokes = 0;
- var env = OrtEnv.CreateInstanceWithOptions(envOptions);
+ var env = OrtEnv.CreateInstanceWithOptions(ref envOptions);
Assert.True(OrtEnv.IsCreated);
var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 8427dfaef4..d4b2310fa1 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -955,6 +955,7 @@ Do not modify directly.*
|||7+|**T** = tensor(float), tensor(float16)
**T1** = tensor(bool)|
|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
|||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
+|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
|HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
|Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)|
|||11+|**T** = tensor(float), tensor(float16)|
@@ -1101,7 +1102,8 @@ Do not modify directly.*
|||11+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)|
|||10+|**T** = tensor(float), tensor(float16)|
|ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|10+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)|
+|RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)|
+|||10+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)|
|Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
|STFT|*in* signal:**T1**
*in* frame_step:**T2**
*in* window:**T1**
*in* frame_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int32), tensor(int64)|
|ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1194,6 +1196,7 @@ Do not modify directly.*
|FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)|
+|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
diff --git a/docs/python/README.rst b/docs/python/README.rst
index f5f162b1d4..e47f2ad19d 100644
--- a/docs/python/README.rst
+++ b/docs/python/README.rst
@@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime mtx(spin_lock_);
+#else
std::lock_guard lock(mutex_);
+#endif
unsigned back = back_.load(std::memory_order_relaxed);
Elem& e = array_[(back - 1) & kMask];
ElemState s = e.state.load(std::memory_order_relaxed);
@@ -469,7 +474,11 @@ class RunQueue {
// with w_idx. Typically the tag will be a per-thread ID to distinguish work
// submitted from different threads.
PushResult PushBackWithTag(Work w, Tag tag, unsigned& w_idx) {
+#ifdef USE_LOCK_FREE_QUEUE
+ std::lock_guard mtx(spin_lock_);
+#else
std::lock_guard lock(mutex_);
+#endif
unsigned back = back_.load(std::memory_order_relaxed);
w_idx = (back - 1) & kMask;
Elem& e = array_[w_idx];
@@ -490,7 +499,11 @@ class RunQueue {
Work PopBack() {
if (Empty())
return Work();
+#ifdef USE_LOCK_FREE_QUEUE
+ std::lock_guard mtx(spin_lock_);
+#else
std::lock_guard lock(mutex_);
+#endif
unsigned back;
Elem* e;
ElemState s;
@@ -532,7 +545,11 @@ class RunQueue {
bool RevokeWithTag(Tag tag, unsigned w_idx) {
bool revoked = false;
+#ifdef USE_LOCK_FREE_QUEUE
+ std::lock_guard mtx(spin_lock_);
+#else
std::lock_guard lock(mutex_);
+#endif
Elem& e = array_[w_idx];
ElemState s = e.state.load(std::memory_order_relaxed);
@@ -604,7 +621,11 @@ class RunQueue {
Work w;
};
+#ifdef USE_LOCK_FREE_QUEUE
+ OrtSpinLock spin_lock_;
+#else
OrtMutex mutex_;
+#endif
// Low log(kSize) + 1 bits in front_ and back_ contain rolling index of
// front/back, respectively. The remaining bits contain modification counters
diff --git a/include/onnxruntime/core/platform/ort_spin_lock.h b/include/onnxruntime/core/platform/ort_spin_lock.h
new file mode 100644
index 0000000000..db80abe1ee
--- /dev/null
+++ b/include/onnxruntime/core/platform/ort_spin_lock.h
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+#include "core/common/spin_pause.h"
+#include <atomic>
+
+namespace onnxruntime {
+/*
+OrtSpinLock implements mutex semantics with a busy-wait spin loop:
+a blocked calling thread is not put to sleep,
+which reduces cpu usage from context switching.
+*/
+struct OrtSpinLock {
+ using LockState = enum { Locked = 0,
+ Unlocked };
+
+ void lock() noexcept {
+ LockState state = Unlocked;
+ while (!state_.compare_exchange_weak(state, Locked, std::memory_order_acq_rel, std::memory_order_relaxed)) {
+ state = Unlocked;
+ concurrency::SpinPause(); // pause and retry
+ }
+ }
+ bool try_lock() noexcept {
+ LockState state = Unlocked;
+ return state_.compare_exchange_weak(state, Locked, std::memory_order_acq_rel, std::memory_order_relaxed);
+ }
+ void unlock() noexcept {
+ state_.store(Unlocked, std::memory_order_release);
+ }
+
+ private:
+  std::atomic<LockState> state_{Unlocked};
+};
+} // namespace onnxruntime
\ No newline at end of file
diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h
index 18dcfeea68..084e66fd81 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h
@@ -14,21 +14,21 @@
/// User can only get the instance of OrtCUDAProviderOptionsV2 via CreateCUDAProviderOptions.
///
struct OrtCUDAProviderOptionsV2 {
- int device_id; // cuda device id.
- int has_user_compute_stream; // indicator of user specified CUDA compute stream.
- void* user_compute_stream; // user specified CUDA compute stream.
- int do_copy_in_default_stream; // flag specifying if the default stream is to be used for copying.
- OrtCudnnConvAlgoSearch cudnn_conv_algo_search; // cudnn algo search enum.
- size_t gpu_mem_limit; // BFC Arena memory limit for CUDA.
- // (will be overridden by contents of `default_memory_arena_cfg` is it exists)
- onnxruntime::ArenaExtendStrategy arena_extend_strategy; // BFC Arena extension strategy.
- // (will be overridden by contents of `default_memory_arena_cfg` is it exists)
- OrtArenaCfg* default_memory_arena_cfg; // BFC Arena config flags.
- int cudnn_conv_use_max_workspace; // flag specifying if maximum workspace can be used in cudnn conv algo search.
- int enable_cuda_graph; // flag specifying if the CUDA graph is to be captured for the model.
- int cudnn_conv1d_pad_to_nc1d; // flag specifying if pad Conv1D's input [N,C,D] to [N,C,1,D] or [N,C,D,1].
- int tunable_op_enable; // flag specifying if TunableOp is enabled.
- int tunable_op_tuning_enable; // flag specifying if TunableOp is enabled for tuning, this relies on TunableOp is enabled.
- int enable_skip_layer_norm_strict_mode; // flag specifying if SkipLayerNorm is in strict mode. If true, use LayerNormalization kernel.
- // The strict mode has better accuracy but lower performance.
+ int device_id = 0; // cuda device id.
+ int has_user_compute_stream = 0; // indicator of user specified CUDA compute stream.
+ void* user_compute_stream = nullptr; // user specified CUDA compute stream.
+ int do_copy_in_default_stream = 1; // flag specifying if the default stream is to be used for copying.
+ OrtCudnnConvAlgoSearch cudnn_conv_algo_search = OrtCudnnConvAlgoSearchExhaustive; // cudnn algo search enum.
+  size_t gpu_mem_limit = std::numeric_limits<size_t>::max();  // BFC Arena memory limit for CUDA.
+  // (will be overridden by contents of `default_memory_arena_cfg` if it exists)
+ onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo; // BFC Arena extension strategy.
+  // (will be overridden by contents of `default_memory_arena_cfg` if it exists)
+ OrtArenaCfg* default_memory_arena_cfg = nullptr; // BFC Arena config flags.
+ int cudnn_conv_use_max_workspace = 1; // flag specifying if maximum workspace can be used in cudnn conv algo search.
+ int enable_cuda_graph = 0; // flag specifying if the CUDA graph is to be captured for the model.
+ int cudnn_conv1d_pad_to_nc1d = 0; // flag specifying if pad Conv1D's input [N,C,D] to [N,C,1,D] or [N,C,D,1].
+ int tunable_op_enable = 0; // flag specifying if TunableOp is enabled.
+ int tunable_op_tuning_enable = 0; // flag specifying if TunableOp is enabled for tuning, this relies on TunableOp is enabled.
+ int enable_skip_layer_norm_strict_mode = 0; // flag specifying if SkipLayerNorm is in strict mode. If true, use LayerNormalization kernel.
+ // The strict mode has better accuracy but lower performance.
};
diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
index 3bd6dac54b..600e255bcd 100644
--- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
+++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -36,7 +36,7 @@ struct OrtTensorRTProviderOptionsV2 {
int trt_detailed_build_log; // Enable detailed build step logging on TensorRT EP with timing for each engine build. Default 0 = false, nonzero = true
int trt_build_heuristics_enable; // Build engine using heuristics to reduce build time. Default 0 = false, nonzero = true
int trt_sparsity_enable; // Control if sparsity can be used by TRT. Default 0 = false, 1 = true
- int trt_builder_optimization_level; // Set the builder optimization level. WARNING: levels below 2 do not guarantee good engine performance, but greatly improve build time. Default 2, valid range [0-4]
+ int trt_builder_optimization_level; // Set the builder optimization level. WARNING: levels below 3 do not guarantee good engine performance, but greatly improve build time. Default 3, valid range [0-5]
int trt_auxiliary_streams; // Set maximum number of auxiliary streams per inference stream. Setting this value to 0 will lead to optimal memory usage. Default -1 = heuristics
const char* trt_tactic_sources; // pecify the tactics to be used by adding (+) or removing (-) tactics from the default
// tactic sources (default = all available tactics) e.g. "-CUDNN,+CUBLAS" available keys: "CUBLAS"|"CUBLAS_LT"|"CUDNN"|"EDGE_MASK_CONVOLUTIONS"
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 7f8cc059a6..d215b12157 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -95,6 +95,8 @@ extern "C" {
#define ORTCHAR_T char
#endif
+/// ORTCHAR_T, ORT_TSTR are reserved specifically for path handling.
+/// All other strings are UTF-8 encoded, use char and std::string
#ifndef ORT_TSTR
#ifdef _WIN32
#define ORT_TSTR(X) L##X
@@ -627,11 +629,19 @@ struct OrtApiBase {
* \param[in] version Must be ::ORT_API_VERSION
* \return The ::OrtApi for the version requested, nullptr will be returned if this version is unsupported, for example when using a runtime
* older than the version created with this header file.
+ *
+ * One can call GetVersionString() to get the version of the Onnxruntime library for logging
+ * and error reporting purposes.
*/
const OrtApi*(ORT_API_CALL* GetApi)(uint32_t version)NO_EXCEPTION;
- const ORTCHAR_T*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; ///< Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1")
- const ORTCHAR_T*(ORT_API_CALL* GetBuildInfoString)(void)NO_EXCEPTION; ///< Returns a null terminated string of the build info including git info and cxx flags
+
+ /** \brief Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1")
+ *
+ * \return UTF-8 encoded version string. Do not deallocate the returned buffer.
+ */
+ const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION;
};
+
typedef struct OrtApiBase OrtApiBase;
/** \brief The Onnxruntime library's entry point to access the C API
@@ -4192,6 +4202,14 @@ struct OrtApi {
* \since Version 1.15.
*/
ORT_API2_STATUS(KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out);
+
+ /** \brief Returns a null terminated string of the build info including git info and cxx flags
+ *
+   * \return UTF-8 encoded build info string. Do not deallocate the returned buffer.
+ *
+ * \since Version 1.15.
+ */
+ const char*(ORT_API_CALL* GetBuildInfoString)(void);
};
/*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index df007b7af4..2d5e1a9bdd 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -125,14 +125,14 @@ inline const OrtApi& GetApi() noexcept { return *Global::api_; }
/// This function returns the onnxruntime version string
///
/// version string major.minor.rev
-std::basic_string<ORTCHAR_T> GetVersionString();
+std::string GetVersionString();
///
/// This function returns the onnxruntime build information: including git branch,
/// git commit id, build type(Debug/Release/RelWithDebInfo) and cmake cpp flags.
///
/// string
-std::basic_string<ORTCHAR_T> GetBuildInfoString();
+std::string GetBuildInfoString();
///
/// This is a C++ wrapper for OrtApi::GetAvailableProviders() and
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 3e7cf721cd..b72bcd35fa 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -1987,14 +1987,12 @@ inline void CustomOpApi::ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_c
api_.ReleaseKernelInfo(info_copy);
}
-inline std::basic_string<ORTCHAR_T> GetVersionString() {
-  std::basic_string<ORTCHAR_T> result = OrtGetApiBase()->GetVersionString();
- return result;
+inline std::string GetVersionString() {
+ return OrtGetApiBase()->GetVersionString();
}
-inline std::basic_string<ORTCHAR_T> GetBuildInfoString() {
-  std::basic_string<ORTCHAR_T> result = OrtGetApiBase()->GetBuildInfoString();
- return result;
+inline std::string GetBuildInfoString() {
+ return GetApi().GetBuildInfoString();
}
inline std::vector GetAvailableProviders() {
diff --git a/java/src/main/java/ai/onnxruntime/OrtProvider.java b/java/src/main/java/ai/onnxruntime/OrtProvider.java
index cb35bf4f50..0da9487c67 100644
--- a/java/src/main/java/ai/onnxruntime/OrtProvider.java
+++ b/java/src/main/java/ai/onnxruntime/OrtProvider.java
@@ -23,7 +23,8 @@ public enum OrtProvider {
ARM_NN("ArmNNExecutionProvider"),
ROCM("ROCMExecutionProvider"),
CORE_ML("CoreMLExecutionProvider"),
- XNNPACK("XnnpackExecutionProvider");
+ XNNPACK("XnnpackExecutionProvider"),
+ AZURE("AzureExecutionProvider");
private static final Map valueMap = new HashMap<>(values().length);
diff --git a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c
index 5165f3499e..659f34e1fb 100644
--- a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c
+++ b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c
@@ -76,13 +76,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxRuntime_getAvailableProvi
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseVersion
(JNIEnv * jniEnv, jclass clazz) {
(void)clazz; // required JNI parameter not needed by functions which don't access their host class.
- const ORTCHAR_T* version = OrtGetApiBase()->GetVersionString();
+ const char* version = ORT_VERSION;
assert(version != NULL);
-#ifdef _WIN32
- jsize len = (jsize)(wcslen(version));
- jstring versionStr = (*jniEnv)->NewString(jniEnv, (const jchar*)version, len);
-#else
- jstring versionStr = (*jniEnv)->NewStringUTF(jniEnv, version);
-#endif
- return versionStr;
+ return (*jniEnv)->NewStringUTF(jniEnv, version);
}
diff --git a/js/common/package-lock.json b/js/common/package-lock.json
index 1d956b5888..49816816a5 100644
--- a/js/common/package-lock.json
+++ b/js/common/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-common",
- "version": "1.15.0",
+ "version": "1.15.1",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-common",
- "version": "1.15.0",
+ "version": "1.15.1",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.23.22"
diff --git a/js/common/package.json b/js/common/package.json
index 39e9e356e3..357fd7b099 100644
--- a/js/common/package.json
+++ b/js/common/package.json
@@ -8,7 +8,7 @@
},
"author": "fs-eire",
"module": "dist/lib/index.js",
- "version": "1.15.0",
+ "version": "1.15.1",
"jsdelivr": "dist/ort-common.min.js",
"scripts": {
"prepare": "tsc && webpack"
diff --git a/js/node/package-lock.json b/js/node/package-lock.json
index 688e434121..8c9b3f049b 100644
--- a/js/node/package-lock.json
+++ b/js/node/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-node",
- "version": "1.15.0",
+ "version": "1.15.1",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-node",
- "version": "1.15.0",
+ "version": "1.15.1",
"license": "MIT",
"os": [
"win32",
@@ -28,7 +28,8 @@
}
},
"../common": {
- "version": "1.15.0",
+ "name": "onnxruntime-common",
+ "version": "1.15.1",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.23.22"
diff --git a/js/node/package.json b/js/node/package.json
index 8bad9d2a21..33d83edb9a 100644
--- a/js/node/package.json
+++ b/js/node/package.json
@@ -13,7 +13,7 @@
3
]
},
- "version": "1.15.0",
+ "version": "1.15.1",
"dependencies": {
"onnxruntime-common": "file:../common"
},
diff --git a/js/package-lock.json b/js/package-lock.json
index deb97d1a07..4163656b34 100644
--- a/js/package-lock.json
+++ b/js/package-lock.json
@@ -2403,8 +2403,8 @@
}
},
"node_modules/fastq": {
- "version": "1.15.0",
- "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz",
+ "version": "1.15.1",
+ "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.1.tgz",
"integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==",
"dev": true,
"dependencies": {
@@ -7532,8 +7532,8 @@
"dev": true
},
"fastq": {
- "version": "1.15.0",
- "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz",
+ "version": "1.15.1",
+ "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.1.tgz",
"integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==",
"dev": true,
"requires": {
diff --git a/js/react_native/package.json b/js/react_native/package.json
index 178b632bce..4a6efe6478 100644
--- a/js/react_native/package.json
+++ b/js/react_native/package.json
@@ -38,7 +38,7 @@
"registry": "https://registry.npmjs.org/"
},
"source": "lib/index",
- "version": "1.15.0",
+ "version": "1.15.1",
"main": "dist/commonjs/index",
"homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md",
"files": [
diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock
index d5925f65d2..c1ecd6754e 100644
--- a/js/react_native/yarn.lock
+++ b/js/react_native/yarn.lock
@@ -5197,7 +5197,7 @@ onetime@^5.1.0, onetime@^5.1.2:
mimic-fn "^2.1.0"
"onnxruntime-common@file:../common":
- version "1.15.0"
+ version "1.15.1"
open@^6.2.0:
version "6.4.0"
diff --git a/js/web/.gitignore b/js/web/.gitignore
index 438068ff32..5b9e31ad5e 100644
--- a/js/web/.gitignore
+++ b/js/web/.gitignore
@@ -15,6 +15,8 @@ test/**/*.js.map
script/**/*.js
script/**/*.js.map
+!/types.d.ts
+
lib/wasm/binding/**/*.wasm
!lib/wasm/binding/**/*.d.ts
diff --git a/js/web/.npmignore b/js/web/.npmignore
index e21a906c9a..8e08db5917 100644
--- a/js/web/.npmignore
+++ b/js/web/.npmignore
@@ -4,8 +4,7 @@
/dist/**/*.report.html
-/types/**/*.d.ts
-!/types/lib/**/*.d.ts
+/types/
karma.conf.js
tsconfig.json
diff --git a/js/web/docs/operators.md b/js/web/docs/operators.md
index 115884cb9d..de84134ddb 100644
--- a/js/web/docs/operators.md
+++ b/js/web/docs/operators.md
@@ -19,7 +19,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [Asinh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Asinh) | |
| [Atan](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Atan) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Atan-7) |
| [Atanh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Atanh) | |
-| [AveragePool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool) | [7-9](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-7), [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-10), [11+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-11) |
+| [AveragePool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool) | [7-9](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-7), [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-10), [11-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-11), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#AveragePool-19) |
| [BatchNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BatchNormalization) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-7), [9-13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-9), [14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-14), [15+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#BatchNormalization-15) |
| [Bernoulli](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Bernoulli) | |
| [BitShift](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BitShift) | |
@@ -28,7 +28,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [BitwiseOr](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BitwiseOr) | |
| [BitwiseXor](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BitwiseXor) | |
| [BlackmanWindow](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BlackmanWindow) | |
-| [Cast](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast) | [6-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-6), [9-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-9), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-13) |
+| [Cast](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast) | [6-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-6), [9-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-9), [13-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-13), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-19) |
| [CastLike](https://github.com/onnx/onnx/blob/main/docs/Operators.md#CastLike) | |
| [Ceil](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Ceil) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Ceil-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Ceil-13) |
| [Celu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Celu) | |
@@ -47,6 +47,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [Cosh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cosh) | |
| [CumSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum) | |
| [DFT](https://github.com/onnx/onnx/blob/main/docs/Operators.md#DFT) | |
+| [DeformConv](https://github.com/onnx/onnx/blob/main/docs/Operators.md#DeformConv) | |
| [DepthToSpace](https://github.com/onnx/onnx/blob/main/docs/Operators.md#DepthToSpace) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#DepthToSpace-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#DepthToSpace-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#DepthToSpace-13) |
| [DequantizeLinear](https://github.com/onnx/onnx/blob/main/docs/Operators.md#DequantizeLinear) | |
| [Det](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Det) | |
@@ -55,7 +56,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [DynamicQuantizeLinear](https://github.com/onnx/onnx/blob/main/docs/Operators.md#DynamicQuantizeLinear) | |
| [Einsum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Einsum) | |
| [Elu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Elu) | [6+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Elu-6) |
-| [Equal](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Equal) | [7-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-7), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-13) |
+| [Equal](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Equal) | [7-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-7), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-11), [13-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-13), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-19) |
| [Erf](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Erf) | |
| [Exp](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Exp) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-13) |
| [Expand](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand) | |
@@ -79,7 +80,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [HardSigmoid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSigmoid) | |
| [HardSwish](https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSwish) | |
| [Hardmax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Hardmax) | |
-| [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16) |
+| [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19) |
| [If](https://github.com/onnx/onnx/blob/main/docs/Operators.md#If) | |
| [InstanceNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#InstanceNormalization) | [6+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#InstanceNormalization-6) |
| [IsInf](https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf) | |
@@ -120,7 +121,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [OptionalHasElement](https://github.com/onnx/onnx/blob/main/docs/Operators.md#OptionalHasElement) | |
| [Or](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Or) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Or-7) |
| [PRelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#PRelu) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-7), [9-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-9), [16+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-16) |
-| [Pad](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-18) |
+| [Pad](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-19) |
| [Pow](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pow) | [7-11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-7), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-12), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-13), [15+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-15) |
| [QLinearConv](https://github.com/onnx/onnx/blob/main/docs/Operators.md#QLinearConv) | |
| [QLinearMatMul](https://github.com/onnx/onnx/blob/main/docs/Operators.md#QLinearMatMul) | |
@@ -143,8 +144,8 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [ReduceSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-11) |
| [ReduceSumSquare](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSumSquare) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-18) |
| [Relu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Relu) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-6), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-14) |
-| [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14) |
-| [Resize](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize) | [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-10), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-18) |
+| [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19) |
+| [Resize](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize) | [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-10), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-19) |
| [ReverseSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReverseSequence) | |
| [RoiAlign](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RoiAlign) | |
| [Round](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Round) | |
@@ -161,7 +162,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [SequenceInsert](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceInsert) | |
| [SequenceLength](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceLength) | |
| [SequenceMap](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceMap) | |
-| [Shape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-1), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-13), [15+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-15) |
+| [Shape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-1), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-13), [15-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-15), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-19) |
| [Shrink](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shrink) | |
| [Sigmoid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sigmoid) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sigmoid-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sigmoid-13) |
| [Sign](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sign) | |
diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js
index 2a4e71e064..860a7d2e20 100644
--- a/js/web/karma.conf.js
+++ b/js/web/karma.conf.js
@@ -63,6 +63,8 @@ module.exports = function (config) {
{ pattern: 'dist/ort-wasm-threaded.wasm', included: false },
{ pattern: 'dist/ort-wasm-simd.wasm', included: false },
{ pattern: 'dist/ort-wasm-simd-threaded.wasm', included: false },
+ { pattern: 'dist/ort-wasm-simd.jsep.wasm', included: false },
+ { pattern: 'dist/ort-wasm-simd-threaded.jsep.wasm', included: false },
{ pattern: 'dist/ort-wasm-threaded.worker.js', included: false },
],
proxies: {
@@ -70,6 +72,8 @@ module.exports = function (config) {
'/base/test/ort-wasm-threaded.wasm': '/base/dist/ort-wasm-threaded.wasm',
'/base/test/ort-wasm-simd.wasm': '/base/dist/ort-wasm-simd.wasm',
'/base/test/ort-wasm-simd-threaded.wasm': '/base/dist/ort-wasm-simd-threaded.wasm',
+ '/base/test/ort-wasm-simd.jsep.wasm': '/base/dist/ort-wasm-simd.jsep.wasm',
+ '/base/test/ort-wasm-simd-threaded.jsep.wasm': '/base/dist/ort-wasm-simd-threaded.jsep.wasm',
'/base/test/ort-wasm-threaded.worker.js': '/base/dist/ort-wasm-threaded.worker.js',
},
plugins: karmaPlugins,
diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts
index 683e15e884..7648f0c473 100644
--- a/js/web/lib/wasm/wasm-factory.ts
+++ b/js/web/lib/wasm/wasm-factory.ts
@@ -6,11 +6,16 @@ import * as path from 'path';
import {OrtWasmModule} from './binding/ort-wasm';
import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded';
-import ortWasmFactory from './binding/ort-wasm.js';
-const ortWasmFactoryThreaded: EmscriptenModuleFactory =
- // eslint-disable-next-line @typescript-eslint/no-require-imports
- !BUILD_DEFS.DISABLE_WASM_THREAD ? require('./binding/ort-wasm-threaded.js') : ortWasmFactory;
+/* eslint-disable @typescript-eslint/no-require-imports */
+const ortWasmFactory: EmscriptenModuleFactory =
+ BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
+
+const ortWasmFactoryThreaded: EmscriptenModuleFactory = !BUILD_DEFS.DISABLE_WASM_THREAD ?
+ (BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm-threaded.js') :
+ require('./binding/ort-wasm-simd-threaded.jsep.js')) :
+ ortWasmFactory;
+/* eslint-enable @typescript-eslint/no-require-imports */
let wasm: OrtWasmModule|undefined;
let initialized = false;
@@ -95,10 +100,10 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise
const useThreads = numThreads > 1 && isMultiThreadSupported();
const useSimd = simd && isSimdSupported();
- const wasmPrefixOverride = typeof flags.wasmPaths === 'string' ? flags.wasmPaths : undefined;
- const wasmFileName = getWasmFileName(false, useThreads);
- const wasmOverrideFileName = getWasmFileName(useSimd, useThreads);
- const wasmPathOverride = typeof flags.wasmPaths === 'object' ? flags.wasmPaths[wasmOverrideFileName] : undefined;
+ const wasmPaths = flags.wasmPaths;
+ const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined;
+ const wasmFileName = getWasmFileName(useSimd, useThreads);
+ const wasmPathOverride = typeof wasmPaths === 'object' ? wasmPaths[wasmFileName] : undefined;
let isTimeout = false;
@@ -130,9 +135,22 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise
{type: 'text/javascript'}));
}
- if (fileName === wasmFileName) {
- const prefix: string = wasmPrefixOverride ?? scriptDirectory;
- return wasmPathOverride ?? prefix + wasmOverrideFileName;
+ if (fileName.endsWith('.wasm')) {
+ if (wasmPathOverride) {
+ return wasmPathOverride;
+ }
+
+ const prefix = wasmPrefixOverride ?? scriptDirectory;
+
+ if (!BUILD_DEFS.DISABLE_WEBGPU) {
+ if (wasmFileName === 'ort-wasm-simd.wasm') {
+ return prefix + 'ort-wasm-simd.jsep.wasm';
+ } else if (wasmFileName === 'ort-wasm-simd-threaded.wasm') {
+ return prefix + 'ort-wasm-simd-threaded.jsep.wasm';
+ }
+ }
+
+ return prefix + wasmFileName;
}
return scriptDirectory + fileName;
diff --git a/js/web/package-lock.json b/js/web/package-lock.json
index ad44566290..dd18e3eac0 100644
--- a/js/web/package-lock.json
+++ b/js/web/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-web",
- "version": "1.15.0",
+ "version": "1.15.1",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-web",
- "version": "1.15.0",
+ "version": "1.15.1",
"license": "MIT",
"dependencies": {
"flatbuffers": "^1.12.0",
@@ -52,7 +52,7 @@
},
"../common": {
"name": "onnxruntime-common",
- "version": "1.15.0",
+ "version": "1.15.1",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.23.22"
@@ -1158,9 +1158,9 @@
}
},
"node_modules/engine.io": {
- "version": "6.4.1",
- "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.1.tgz",
- "integrity": "sha512-JFYQurD/nbsA5BSPmbaOSLa3tSVj8L6o4srSwXXY3NqE+gGUNmmPTbhn8tjzcCtSqhFgIeqef81ngny8JM25hw==",
+ "version": "6.4.2",
+ "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.2.tgz",
+ "integrity": "sha512-FKn/3oMiJjrOEOeUub2WCox6JhxBXq/Zn3fZOMCBxKnNYtsdKjxhl7yR3fZhM9PV+rdE75SU5SYMc+2PGzo+Tg==",
"dev": true,
"dependencies": {
"@types/cookie": "^0.4.1",
@@ -4933,9 +4933,9 @@
}
},
"engine.io": {
- "version": "6.4.1",
- "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.1.tgz",
- "integrity": "sha512-JFYQurD/nbsA5BSPmbaOSLa3tSVj8L6o4srSwXXY3NqE+gGUNmmPTbhn8tjzcCtSqhFgIeqef81ngny8JM25hw==",
+ "version": "6.4.2",
+ "resolved": "https://registry.npmjs.org/engine.io/-/engine.io-6.4.2.tgz",
+ "integrity": "sha512-FKn/3oMiJjrOEOeUub2WCox6JhxBXq/Zn3fZOMCBxKnNYtsdKjxhl7yR3fZhM9PV+rdE75SU5SYMc+2PGzo+Tg==",
"dev": true,
"requires": {
"@types/cookie": "^0.4.1",
diff --git a/js/web/package.json b/js/web/package.json
index 9b18dcbd23..0de6a2ef1f 100644
--- a/js/web/package.json
+++ b/js/web/package.json
@@ -8,7 +8,7 @@
"type": "git"
},
"author": "fs-eire",
- "version": "1.15.0",
+ "version": "1.15.1",
"jsdelivr": "dist/ort.min.js",
"dependencies": {
"flatbuffers": "^1.12.0",
@@ -66,6 +66,22 @@
"strip-json-comments": "^5.0.0"
},
"main": "dist/ort-web.node.js",
- "types": "./types/lib/index.d.ts",
+ "exports": {
+ ".": {
+ "node": {
+ "types": "./types.d.ts",
+ "default": "./dist/ort-web.node.js"
+ },
+ "default": {
+ "types": "./types.d.ts",
+ "default": "./dist/ort.min.js"
+ }
+ },
+ "./webgpu": {
+ "types": "./types.d.ts",
+ "default": "./dist/ort.webgpu.min.js"
+ }
+ },
+ "types": "./types.d.ts",
"description": "A Javascript library for running ONNX models on browsers"
}
diff --git a/js/web/script/build.ts b/js/web/script/build.ts
index ebdce462b9..d3a5be429b 100644
--- a/js/web/script/build.ts
+++ b/js/web/script/build.ts
@@ -34,8 +34,11 @@ const ROOT_FOLDER = path.join(__dirname, '..');
const WASM_BINDING_FOLDER = path.join(ROOT_FOLDER, 'lib', 'wasm', 'binding');
const WASM_BINDING_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm.js');
const WASM_BINDING_THREADED_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.js');
+const WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.js');
const WASM_BINDING_THREADED_WORKER_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.worker.js');
const WASM_BINDING_THREADED_MIN_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.min.js');
+const WASM_BINDING_SIMD_THREADED_JSEP_MIN_JS_PATH =
+ path.join(WASM_BINDING_FOLDER, 'ort-wasm-simd-threaded.jsep.min.js');
const WASM_BINDING_THREADED_MIN_WORKER_JS_PATH = path.join(WASM_BINDING_FOLDER, 'ort-wasm-threaded.min.worker.js');
const WASM_DIST_FOLDER = path.join(ROOT_FOLDER, 'dist');
@@ -43,8 +46,11 @@ const WASM_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm.wasm');
const WASM_THREADED_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-threaded.wasm');
const WASM_SIMD_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd.wasm');
const WASM_SIMD_THREADED_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.wasm');
+const WASM_SIMD_JSEP_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd.jsep.wasm');
+const WASM_SIMD_THREADED_JSEP_WASM_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm');
const WASM_THREADED_WORKER_JS_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-threaded.worker.js');
const WASM_THREADED_JS_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-threaded.js');
+const WASM_SIMD_THREADED_JSEP_JS_PATH = path.join(WASM_DIST_FOLDER, 'ort-wasm-simd-threaded.jsep.js');
function validateFile(path: string): void {
npmlog.info('Build', `Ensure file: ${path}`);
@@ -61,11 +67,14 @@ if (WASM) {
try {
validateFile(WASM_BINDING_JS_PATH);
validateFile(WASM_BINDING_THREADED_JS_PATH);
+ validateFile(WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH);
validateFile(WASM_BINDING_THREADED_WORKER_JS_PATH);
validateFile(WASM_WASM_PATH);
validateFile(WASM_THREADED_WASM_PATH);
validateFile(WASM_SIMD_WASM_PATH);
validateFile(WASM_SIMD_THREADED_WASM_PATH);
+ validateFile(WASM_SIMD_JSEP_WASM_PATH);
+ validateFile(WASM_SIMD_THREADED_JSEP_WASM_PATH);
} catch (e) {
npmlog.error('Build', `WebAssembly files are not ready. build WASM first. ERR: ${e}`);
throw e;
@@ -86,7 +95,7 @@ if (WASM) {
'npx',
[
'terser', WASM_BINDING_THREADED_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false', '--mangle',
- 'reserved=[_scriptDir]', '--module'
+ 'reserved=[_scriptDir,startWorker]', '--module'
],
{shell: true, encoding: 'utf-8', cwd: ROOT_FOLDER});
if (terser.status !== 0) {
@@ -105,13 +114,38 @@ if (WASM) {
}
npmlog.info('Build', 'Minimizing file "ort-wasm-threaded.js"... DONE');
+ npmlog.info('Build', 'Minimizing file "ort-wasm-simd-threaded.jsep.js"...');
+ try {
+ const terser = spawnSync(
+ 'npx',
+ [
+ 'terser', WASM_BINDING_SIMD_THREADED_JSEP_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false',
+ '--mangle', 'reserved=[_scriptDir,startWorker]', '--module'
+ ],
+ {shell: true, encoding: 'utf-8', cwd: ROOT_FOLDER});
+ if (terser.status !== 0) {
+ console.error(terser.error);
+ process.exit(terser.status === null ? undefined : terser.status);
+ }
+
+ fs.writeFileSync(WASM_BINDING_SIMD_THREADED_JSEP_MIN_JS_PATH, terser.stdout);
+ fs.writeFileSync(WASM_SIMD_THREADED_JSEP_JS_PATH, `${COPYRIGHT_BANNER}${terser.stdout}`);
+
+ validateFile(WASM_BINDING_SIMD_THREADED_JSEP_MIN_JS_PATH);
+ validateFile(WASM_SIMD_THREADED_JSEP_JS_PATH);
+ } catch (e) {
+ npmlog.error('Build', `Failed to run terser on ort-wasm-simd-threaded.jsep.js. ERR: ${e}`);
+ throw e;
+ }
+ npmlog.info('Build', 'Minimizing file "ort-wasm-simd-threaded.jsep.js"... DONE');
+
npmlog.info('Build', 'Minimizing file "ort-wasm-threaded.worker.js"...');
try {
const terser = spawnSync(
'npx',
[
'terser', WASM_BINDING_THREADED_WORKER_JS_PATH, '--compress', 'passes=2', '--format', 'comments=false',
- '--mangle', 'reserved=[_scriptDir]', '--toplevel'
+ '--mangle', 'reserved=[_scriptDir,startWorker]', '--toplevel'
],
{shell: true, encoding: 'utf-8'});
if (terser.status !== 0) {
diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts
index bd6fa1af4f..8c2f24cbf7 100644
--- a/js/web/script/pull-prebuilt-wasm-artifacts.ts
+++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts
@@ -112,10 +112,14 @@ downloadJson(
extractFile(zip, WASM_FOLDER, 'ort-wasm-threaded.wasm', 'Release_wasm');
extractFile(zip, WASM_FOLDER, 'ort-wasm-simd.wasm', 'Release_wasm');
extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.wasm', 'Release_wasm');
+ extractFile(zip, WASM_FOLDER, 'ort-wasm-simd.jsep.wasm', 'Release_wasm');
+ extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm', 'Release_wasm');
extractFile(zip, JS_FOLDER, 'ort-wasm.js', 'Release_wasm');
extractFile(zip, JS_FOLDER, 'ort-wasm-threaded.js', 'Release_wasm');
extractFile(zip, JS_FOLDER, 'ort-wasm-threaded.worker.js', 'Release_wasm');
+ extractFile(zip, JS_FOLDER, 'ort-wasm-simd.jsep.js', 'Release_wasm');
+ extractFile(zip, JS_FOLDER, 'ort-wasm-simd-threaded.jsep.js', 'Release_wasm');
});
});
});
diff --git a/js/web/types.d.ts b/js/web/types.d.ts
new file mode 100644
index 0000000000..c6cff64c8a
--- /dev/null
+++ b/js/web/types.d.ts
@@ -0,0 +1,10 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+declare module 'onnxruntime-web' {
+ export * from 'onnxruntime-common';
+}
+
+declare module 'onnxruntime-web/webgpu' {
+ export * from 'onnxruntime-web';
+}
diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js
index 1c842ddced..84ab9bcae8 100644
--- a/js/web/webpack.config.js
+++ b/js/web/webpack.config.js
@@ -49,7 +49,7 @@ function defaultTerserPluginOptions(target) {
passes: 2
},
mangle: {
- reserved: ["_scriptDir"]
+ reserved: ["_scriptDir","startWorker"]
}
}
};
@@ -57,7 +57,7 @@ function defaultTerserPluginOptions(target) {
const DEFAULT_BUILD_DEFS = {
DISABLE_WEBGL: false,
- DISABLE_WEBGPU: false,
+ DISABLE_WEBGPU: true,
DISABLE_WASM: false,
DISABLE_WASM_PROXY: false,
DISABLE_WASM_THREAD: false,
@@ -123,6 +123,7 @@ function buildConfig({ filename, format, target, mode, devtool, build_defs }) {
if (mode === 'production') {
config.resolve.alias['./binding/ort-wasm-threaded.js'] = './binding/ort-wasm-threaded.min.js';
+ config.resolve.alias['./binding/ort-wasm-simd-threaded.jsep.js'] = './binding/ort-wasm-simd-threaded.jsep.min.js';
config.resolve.alias['./binding/ort-wasm-threaded.worker.js'] = './binding/ort-wasm-threaded.min.worker.js';
const options = defaultTerserPluginOptions(target);
@@ -211,7 +212,8 @@ function buildTestRunnerConfig({
format = 'umd',
target = 'es2017',
mode = 'production',
- devtool = 'source-map'
+ devtool = 'source-map',
+ build_defs = DEFAULT_BUILD_DEFS
}) {
const config = {
target: ['web', target],
@@ -244,7 +246,7 @@ function buildTestRunnerConfig({
}
},
plugins: [
- new webpack.DefinePlugin({ BUILD_DEFS: DEFAULT_BUILD_DEFS }),
+ new webpack.DefinePlugin({ BUILD_DEFS: build_defs }),
new webpack.WatchIgnorePlugin({ paths: [/\.js$/, /\.d\.ts$/] }),
new NodePolyfillPlugin({
excludeAliases: ["console", "Buffer"]
@@ -315,6 +317,13 @@ module.exports = () => {
DISABLE_WASM_THREAD: true,
}
}),
+ // ort.webgpu.min.js
+ buildOrtConfig({
+ suffix: '.webgpu.min', build_defs: {
+ ...DEFAULT_BUILD_DEFS,
+ DISABLE_WEBGPU: false,
+ }
+ }),
// ort-web.min.js
buildOrtWebConfig({ suffix: '.min' }),
@@ -333,10 +342,20 @@ module.exports = () => {
);
break;
case 'dev':
- builds.push(buildTestRunnerConfig({ suffix: '.dev', mode: 'development', devtool: 'inline-source-map' }));
+ builds.push(buildTestRunnerConfig({
+ suffix: '.dev', mode: 'development', devtool: 'inline-source-map', build_defs: {
+ ...DEFAULT_BUILD_DEFS,
+ DISABLE_WEBGPU: false,
+ }
+ }));
break;
case 'perf':
- builds.push(buildTestRunnerConfig({ suffix: '.perf' }));
+ builds.push(buildTestRunnerConfig({
+ suffix: '.perf', build_defs: {
+ ...DEFAULT_BUILD_DEFS,
+ DISABLE_WEBGPU: false,
+ }
+ }));
break;
default:
throw new Error(`unsupported bundle mode: ${bundleMode}`);
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index 1da5824cf7..14aa188ac0 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -7,7 +7,7 @@
For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_
or the `Github project `_.
"""
-__version__ = "1.15.0"
+__version__ = "1.15.1"
__author__ = "Microsoft"
# we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 31b1b12d38..90f5abab87 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -48,6 +48,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Quick
// ******** Start: Quantization ******************* //
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConvTranspose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearConcat);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearWhere);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearAveragePool);
@@ -181,6 +182,7 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cpu/tokenizer.cc b/onnxruntime/contrib_ops/cpu/tokenizer.cc
index 45998b6d83..1787fb9b3c 100644
--- a/onnxruntime/contrib_ops/cpu/tokenizer.cc
+++ b/onnxruntime/contrib_ops/cpu/tokenizer.cc
@@ -242,7 +242,7 @@ Status Tokenizer::SeparatorExpressionTokenizer(OpKernelContext* ctx,
token_len, utf8_chars);
if (!valid) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
- "Match contains invalid utf8 chars: " + submatch.as_string());
+ "Match contains invalid utf8 chars: " + std::string{submatch});
}
if (utf8_chars >= size_t(mincharnum_)) {
tokens.emplace_back(text.data() + start_pos, token_len);
@@ -384,7 +384,7 @@ Status Tokenizer::TokenExpression(OpKernelContext* ctx,
utf8_chars = 0;
if (!utf8_len(reinterpret_cast(submatch.data()), token_len, utf8_chars)) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
- "Match contains invalid utf8 chars: " + submatch.as_string());
+ "Match contains invalid utf8 chars: " + std::string{submatch});
}
if (utf8_chars >= size_t(mincharnum_)) {
row.push_back(submatch);
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index def1508ca2..bd1498c0f9 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -164,6 +164,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
has_memory_efficient_attention(sm, sizeof(T) == 2);
#else
constexpr bool use_memory_efficient_attention = false;
+ ORT_UNUSED_VARIABLE(is_mask_1d_key_seq_len_start);
#endif
cublasHandle_t cublas = GetCublasHandle(context);
diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
index b5c8dcca32..ebe87158d1 100644
--- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias_impl.cu
@@ -156,7 +156,7 @@ struct TypeMapper : public V_vec_m_ {};
// The following operator overriding is not common so we put it in anonymous namespace
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 530
inline __device__ half2 operator*(const float a, const half2 b) {
- return __hmul2_rn(__float2half2_rn(a), b);
+ return __hmul2(__float2half2_rn(a), b);
}
#else
inline __device__ half2 operator*(const float a, const half2 b) {
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index a6e040812a..1334708aad 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -1324,6 +1324,7 @@ class PlannerImpl {
#endif
ORT_RETURN_IF_ERROR(ComputeSingleStreamReusePlan(i));
ClearUseCount();
+ freelist_.clear(); // DONOT share freelist across streams
}
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
CalculateLifetime(ort_value_usecount);
@@ -1745,9 +1746,8 @@ class PlannerImpl {
#else
- void
- PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers,
- const PathString& partition_config_file) {
+ void PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers,
+ const PathString& partition_config_file) {
auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file);
auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder());
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
@@ -1760,7 +1760,7 @@ class PlannerImpl {
num_logic_streams_ = stream_nodes_.size();
}
- // build each logic streams
+ // Build each logic streams
Status BuildExecutionPlan(const ExecutionProviders& execution_providers,
const IStreamCommandHandleRegistry& stream_handle_registry) {
// 1. create logic stream instance
@@ -1780,12 +1780,12 @@ class PlannerImpl {
execution_plan.emplace_back(nullptr);
}
}
- // 2. determing following things:
- // a. which node need to generate notification
- // b. which node need to trigger downstream
+ // 2. Determining following things:
+ // a. which node needs to generate the notification
+ // b. which node needs to trigger downstream
#ifdef ENABLE_TRAINING
// We will leverage the topological order for the training scenario.
- // The nodes before yieldOp in topo order will be executed in RunForward() and nodes after will be executed in RunBackward()
+ // The nodes before yieldOp in topo-order will be executed in RunForward() and nodes after will be executed in RunBackward()
// This partition may not be exactly the same as forward model/gradient model, for example, some nodes in gradient model are
// before yieldOp thus will be executed in RunForward()
// But the final result is still correct, as long as all the nodes will be executed in either RunForward() or RunBackward()
@@ -1820,7 +1820,7 @@ class PlannerImpl {
if (node_stream_map_[it->Index()] != i
#ifdef ENABLE_TRAINING
// Do not insert Barrier/TriggerDownStream step if the producer and consumer are in different sides of yieldOp
- // As in this case producer will surely be ready before consumer is running.
+ // As in this case producer will surely be ready before the consumer is running.
&& !AreNodesSeparatedByYield(node_index, it->Index())
#endif
) {
@@ -2048,8 +2048,7 @@ class PlannerImpl {
}
#endif
- static bool
- IsNonTensor(const onnxruntime::NodeArg& nodearg) {
+ static bool IsNonTensor(const onnxruntime::NodeArg& nodearg) {
// TODO: unclear why we should go through a string-representation of type
auto ptype = nodearg.Type();
auto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(ptype);
diff --git a/onnxruntime/core/framework/cloud_invoker.cc b/onnxruntime/core/framework/cloud_invoker.cc
index d6883d408f..a2e9ec97cf 100644
--- a/onnxruntime/core/framework/cloud_invoker.cc
+++ b/onnxruntime/core/framework/cloud_invoker.cc
@@ -2,7 +2,9 @@
// Licensed under the MIT License.
#ifdef USE_AZURE
+#define CURL_STATICLIB
#include "http_client.h"
+#include "curl/curl.h"
#include "core/common/common.h"
#include "core/framework/cloud_invoker.h"
#include "core/framework/ort_value.h"
@@ -18,13 +20,14 @@ namespace onnxruntime {
namespace tc = triton::client;
-const char* kAzureUri = "azure.uri";
-const char* kAzureModelName = "azure.model_name";
-const char* kAzureModelVer = "azure.model_version";
-const char* kAzureVerbose = "azure.verbose";
-const char* kAzureEndpointType = "azure.endpoint_type";
-const char* kAzureAuthKey = "azure.auth_key";
-const char* kAzureTriton = "triton";
+constexpr const char* kAzureUri = "azure.uri";
+constexpr const char* kAzureModelName = "azure.model_name";
+constexpr const char* kAzureModelVer = "azure.model_version";
+constexpr const char* kAzureVerbose = "azure.verbose";
+constexpr const char* kAzureEndpointType = "azure.endpoint_type";
+constexpr const char* kAzureAuthKey = "azure.auth_key";
+constexpr const char* kAzureTriton = "triton";
+constexpr const char* kAzureOpenAI = "openai";
CloudEndPointInvoker::CloudEndPointInvoker(const CloudEndPointConfig& config,
const AllocatorPtr& allocator) : config_(config), allocator_(allocator) {
@@ -33,6 +36,163 @@ CloudEndPointInvoker::CloudEndPointInvoker(const CloudEndPointConfig& config,
}
}
+class CurlGlobal {
+ public:
+ static void Init() {
+ static CurlGlobal curl_global;
+ }
+
+ private:
+ CurlGlobal() {
+ // Thread-safety is a must since curl might also be initialized in triton client.
+ const auto* info = curl_version_info(CURLVERSION_NOW);
+ ORT_ENFORCE(info->features & CURL_VERSION_THREADSAFE, "curl global init not thread-safe, need to upgrade curl version!");
+ ORT_ENFORCE(curl_global_init(CURL_GLOBAL_DEFAULT) == CURLE_OK, "Failed to initialize curl global env!");
+ }
+ ~CurlGlobal() {
+ curl_global_cleanup();
+ }
+ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CurlGlobal);
+};
+
+// OpenAIInvoker
+class OpenAIInvoker : public CloudEndPointInvoker {
+ public:
+ OpenAIInvoker(const CloudEndPointConfig& config, const AllocatorPtr& allocator);
+ onnxruntime::Status Send(const CloudEndPointConfig& run_options,
+ const InlinedVector& input_names,
+ gsl::span ort_inputs,
+ const InlinedVector& output_names,
+ std::vector& ort_outputs) const override;
+
+ private:
+ std::string uri_;
+ std::string model_name_;
+};
+
+OpenAIInvoker::OpenAIInvoker(const CloudEndPointConfig& config,
+ const AllocatorPtr& allocator) : CloudEndPointInvoker(config, allocator) {
+ CurlGlobal::Init();
+ ReadConfig(kAzureUri, uri_);
+ ReadConfig(kAzureModelName, model_name_);
+}
+
+struct StringBuffer {
+ StringBuffer() = default;
+ ~StringBuffer() = default;
+ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(StringBuffer);
+ std::stringstream ss_;
+};
+
+// apply the callback only when response is for sure to be a '/0' terminated string
+static size_t WriteStringCallback(void* contents, size_t size, size_t nmemb, void* userp) {
+ try {
+ size_t realsize = size * nmemb;
+ auto buffer = reinterpret_cast(userp);
+ buffer->ss_.write(reinterpret_cast(contents), realsize);
+ return realsize;
+ } catch (...) {
+ // exception caught, abort write
+ return CURLcode::CURLE_WRITE_ERROR;
+ }
+}
+
+using CurlWriteCallBack = size_t (*)(void*, size_t, size_t, void*);
+
+class CurlHandler {
+ public:
+ CurlHandler(CurlWriteCallBack call_back) : curl_(curl_easy_init(), curl_easy_cleanup),
+ headers_(nullptr, curl_slist_free_all),
+ from_holder_(from_, curl_formfree) {
+ curl_easy_setopt(curl_.get(), CURLOPT_BUFFERSIZE, 102400L);
+ curl_easy_setopt(curl_.get(), CURLOPT_NOPROGRESS, 1L);
+ curl_easy_setopt(curl_.get(), CURLOPT_USERAGENT, "curl/7.83.1");
+ curl_easy_setopt(curl_.get(), CURLOPT_MAXREDIRS, 50L);
+ curl_easy_setopt(curl_.get(), CURLOPT_FTP_SKIP_PASV_IP, 1L);
+ curl_easy_setopt(curl_.get(), CURLOPT_TCP_KEEPALIVE, 1L);
+ curl_easy_setopt(curl_.get(), CURLOPT_WRITEFUNCTION, call_back);
+ }
+ ~CurlHandler() = default;
+
+ void AddHeader(const char* data) {
+ headers_.reset(curl_slist_append(headers_.release(), data));
+ }
+ template
+ void AddForm(Args... args) {
+ curl_formadd(&from_, &last_, args...);
+ }
+ template
+ void SetOption(CURLoption opt, T val) {
+ curl_easy_setopt(curl_.get(), opt, val);
+ }
+ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CurlHandler);
+ CURLcode Perform() {
+ SetOption(CURLOPT_HTTPHEADER, headers_.get());
+ SetOption(CURLOPT_HTTPPOST, from_);
+ return curl_easy_perform(curl_.get());
+ }
+
+ private:
+ std::unique_ptr curl_;
+ std::unique_ptr headers_;
+ curl_httppost* from_{};
+ curl_httppost* last_{};
+ std::unique_ptr from_holder_;
+};
+
+onnxruntime::Status OpenAIInvoker::Send(const CloudEndPointConfig& run_options,
+ const InlinedVector& /*input_names*/,
+ gsl::span ort_inputs,
+ const InlinedVector& /*output_names*/,
+ std::vector& ort_outputs) const {
+ const auto auth_key_iter = run_options.find(kAzureAuthKey);
+ if (run_options.end() == auth_key_iter || auth_key_iter->second.empty()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "auth key must be specified for openai client");
+ }
+ long verbose = 0;
+ const auto verbose_iter = run_options.find(kAzureVerbose);
+ if (run_options.end() != verbose_iter) {
+ verbose = verbose_iter->second != "0" ? 1L : 0L;
+ }
+
+ CurlHandler curl_handler(WriteStringCallback);
+ StringBuffer string_buffer;
+
+ std::string full_auth = std::string{"Authorization: Bearer "} + auth_key_iter->second;
+ curl_handler.AddHeader(full_auth.c_str());
+ curl_handler.AddHeader("Content-Type: multipart/form-data");
+
+ const auto& tensor = ort_inputs[0].Get();
+ auto data_size = tensor.SizeInBytes();
+ curl_handler.AddForm(CURLFORM_COPYNAME, "model", CURLFORM_COPYCONTENTS, model_name_.c_str(), CURLFORM_END);
+ curl_handler.AddForm(CURLFORM_COPYNAME, "response_format", CURLFORM_COPYCONTENTS, "text", CURLFORM_END);
+ curl_handler.AddForm(CURLFORM_COPYNAME, "file", CURLFORM_BUFFER, "non_exist.wav", CURLFORM_BUFFERPTR, tensor.DataRaw(),
+ CURLFORM_BUFFERLENGTH, data_size, CURLFORM_END);
+
+ curl_handler.SetOption(CURLOPT_URL, uri_.c_str());
+ curl_handler.SetOption(CURLOPT_VERBOSE, verbose);
+ curl_handler.SetOption(CURLOPT_WRITEDATA, (void*)&string_buffer);
+
+ auto curl_ret = curl_handler.Perform();
+ if (CURLE_OK != curl_ret) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, curl_easy_strerror(curl_ret));
+ }
+
+ auto output_tensor = std::make_unique(onnxruntime::DataTypeImpl::GetType(), TensorShape{1}, allocator_);
+ if (!output_tensor) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create output tensor");
+ }
+
+ auto* output_string = output_tensor->MutableData();
+ *output_string = string_buffer.ss_.str();
+ auto tensor_type = DataTypeImpl::GetType();
+ ort_outputs.clear();
+ ort_outputs.emplace_back(output_tensor.release(), tensor_type, tensor_type->GetDeleteFunc());
+ return Status::OK();
+}
+
+// AzureTritonInvoker
class AzureTritonInvoker : public CloudEndPointInvoker {
public:
AzureTritonInvoker(const CloudEndPointConfig& config, const AllocatorPtr& allocator);
@@ -287,6 +447,9 @@ Status CloudEndPointInvoker::CreateInvoker(const CloudEndPointConfig& config,
if (iter->second == kAzureTriton) {
invoker = std::make_unique(config, allocator);
return status;
+ } else if (iter->second == kAzureOpenAI) {
+ invoker = std::make_unique(config, allocator);
+ return status;
} // else other endpoint types ...
}
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc
index 00aeff37d7..d9c49dc6be 100644
--- a/onnxruntime/core/framework/execution_frame.cc
+++ b/onnxruntime/core/framework/execution_frame.cc
@@ -28,6 +28,7 @@ using namespace onnxruntime::common;
namespace onnxruntime {
#ifdef ORT_ENABLE_STREAM
static StreamAwareArena* AsStreamBasedAllocator(AllocatorPtr allocator) {
+ ORT_ENFORCE(allocator.get() != nullptr, "allocator is nullptr");
if (allocator->Info().alloc_type == OrtArenaAllocator) {
BFCArena* arena_ptr = static_cast(allocator.get());
return StreamAwareArena::FromBFCArena(*arena_ptr);
@@ -137,7 +138,7 @@ Status IExecutionFrame::GetOutputs(gsl::span fetch_mlvalue_idxs, std:
#endif
-// Return nullptr if index map to an value that is an unused optional input/output
+// Return nullptr if index map to a value that is an unused optional input/output
const OrtValue* IExecutionFrame::GetNodeInputOrOutputMLValue(int index) const {
int ort_value_idx = GetNodeIdxToMLValueIdx(index);
return ort_value_idx != NodeIndexInfo::kInvalidEntry ? &(all_values_[ort_value_idx]) : nullptr;
@@ -147,9 +148,9 @@ OrtValue* IExecutionFrame::GetMutableNodeInputOrOutputMLValue(int index) {
return const_cast(GetNodeInputOrOutputMLValue(index));
}
-// TO DO: make it thread safe
-// This method is not thread safe!
-// Return S_OK and nullptr if index map to an value that is an unused optional input/output
+// TO DO: make it thread-safe
+// This method is not thread-safe!
+// Return S_OK and nullptr if index map to a value that is an unused optional input/output
Status IExecutionFrame::GetOrCreateNodeOutputMLValue(const int output_index, int output_arg_index,
const TensorShape* shape, OrtValue*& p_ort_value,
@@ -191,7 +192,7 @@ Status IExecutionFrame::GetOrCreateNodeOutputMLValue(const int output_index, int
}
bool IExecutionFrame::TryGetInferredShape(int /*index*/, TensorShape& /*shape*/) const {
- // By default, there is not information about inferred shape, so this default
+ // By default, there is no information about inferred shape, so this default
// implementation always returns false. The derived class of IExecutionFrame
// can override this function to provide, for example, activations' shape information.
return false;
@@ -213,7 +214,7 @@ Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) {
}
int IExecutionFrame::GetNodeIdxToMLValueIdx(int index) const {
- // the validity of index is checked by GetMLValueIndex
+ // The validity of the index is checked by GetMLValueIndex
int ort_value_idx = node_index_info_.GetMLValueIndex(index);
return ort_value_idx;
}
@@ -241,7 +242,7 @@ void IExecutionFrame::Init(gsl::span feed_mlvalue_idxs, gsl::span feed_mlvalue_idxs, gsl::span feed_mlvalue_idxs, gsl::span feed_mlvalue_idxs, gsl::span
// Planning of one memory type should only happen once.
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
ORT_ENFORCE(
- static_activation_memory_sizes_in_byte_.find(location.name) ==
+ static_activation_memory_sizes_in_byte_.find(location.ToString()) ==
static_activation_memory_sizes_in_byte_.end(),
"Memory type ",
- location.name,
+ location.ToString(),
" should only appear once.");
// static_activation_memory_in_bytes_ is max virtual memory size the planner computes.
// Memory dynamically allocated when executing kernels is not recorded using this field.
- static_activation_memory_sizes_in_byte_[location.name] = peak_size;
+ static_activation_memory_sizes_in_byte_[location.ToString()] = peak_size;
#endif
// the memory pattern buffer will leave in the whole execution.
#ifdef ORT_ENABLE_STREAM
@@ -535,7 +536,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
AllocatorPtr alloc = nullptr;
// if we have pre-calculated memory pattern, and the ort_value is not output mlvalue
- // try to allocated on pre-allocated big chunk.
+ // try to allocate on pre-allocated big chunk.
const auto& per_alloc_plan = GetAllocationPlan(ort_value_index);
if (mem_patterns_ && per_alloc_plan.alloc_kind != AllocKind::kAllocateOutput &&
@@ -557,11 +558,11 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
} else {
// the block size may vary especially if the model has NonZero ops, or different sequence lengths are
// fed in, so use VERBOSE as the log level as it's expected.
- // TODO: Should we re-use the block if the size is large enough? Would probably need to allow it
+ // TODO: Should we reuse the block if the size is large enough? Would probably need to allow it
// to be freed if the size difference was too large so our memory usage doesn't stick at a high water mark
LOGS(session_state_.Logger(), VERBOSE) << "For ort_value with index: " << ort_value_index
<< ", block in memory pattern size is: " << block->size_
- << " but the actually size is: " << size
+ << " but the actual size is: " << size
<< ", fall back to default allocation behavior";
}
}
@@ -572,6 +573,8 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
// no memory pattern, or the pattern is not correct.
if (!alloc) alloc = GetAllocator(location);
+ ORT_ENFORCE(alloc && alloc.get() != nullptr, "Failed to get allocator for ", location.ToString());
+
Stream* current_stream = GetValueStream(ort_value_index);
if (current_stream) {
#ifdef ORT_ENABLE_STREAM
@@ -606,7 +609,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
// Dynamic activation size would be accessed by multiple threads
// if parallel executor is used.
std::unique_lock lock(mtx_);
- dynamic_activation_memory_sizes_in_byte_[location.name] += size;
+ dynamic_activation_memory_sizes_in_byte_[location.ToString()] += size;
session_state_.GetMemoryProfiler()->GetMemoryInfo().SetDynamicAllocation(ort_value_index);
#endif
}
@@ -825,7 +828,7 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const {
}
// This method is not thread safe!
-// Return S_OK and nullptr if index map to an value that is an unused optional input/output
+// Return S_OK and nullptr if index map to a value that is an unused optional input/output
Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) {
return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape);
}
@@ -930,7 +933,7 @@ bool ExecutionFrame::TryGetInferredShape(int index, TensorShape& shape) const {
}
// Search for inferred shape.
- // If inferred shape is found, it's assigned to "shape" so that caller can use it.
+ // If the inferred shape is found, it's assigned to "shape" so that caller can use it.
if (inferred_shapes_ != nullptr) {
auto it = inferred_shapes_->find(ort_value_idx);
if (it != inferred_shapes_->end()) {
diff --git a/onnxruntime/core/framework/memory_info.cc b/onnxruntime/core/framework/memory_info.cc
index 41c8191778..afb338f47f 100644
--- a/onnxruntime/core/framework/memory_info.cc
+++ b/onnxruntime/core/framework/memory_info.cc
@@ -130,7 +130,7 @@ void PrintInforPerTensor(const MemoryInfo::AllocInfoPerTensor& alloc_info, const
std::cout << "Index: " << alloc_info.mlvalue_index << ", ";
std::cout << "Reuse inplace: " << alloc_info.inplace_reuse << ", ";
std::cout << "Alloc type: " << alloc_info.alloc_kind << ", ";
- std::cout << "Location: " << alloc_info.location.name << ", ";
+ std::cout << "Location: " << alloc_info.location.ToString() << ", ";
std::cout << "lifetime: (" << alloc_info.lifetime_interval.first << ", " << alloc_info.lifetime_interval.second << "), ";
std::cout << "planned block: (" << mem_info.planned_block.offset_ << ", "
<< (mem_info.planned_block.offset_ + mem_info.planned_block.size_) << "), ";
@@ -142,24 +142,24 @@ void PrintInforPerTensor(const MemoryInfo::AllocInfoPerTensor& alloc_info, const
void MemoryInfo::PrintMemoryInfoForLocation(const OrtDevice::DeviceType location) {
for (const auto& map : tensors_memory_info_map_) {
- if (map.first.device.Type() != location) continue;
- std::cout << "Initializer in " << map.first.name << "\n";
+ if (map.first.Type() != location) continue;
+ std::cout << "Initializer in " << map.first.ToString() << "\n";
const auto& initailizer_map = map.second.at(MapType::Initializer);
for (const auto& item : initailizer_map) {
- if (AllocPlan(item.first)->location.device.Type() != location) continue;
+ if (AllocPlan(item.first)->location.Type() != location) continue;
PrintInforPerTensor(*AllocPlan(item.first), item.second, initailizer_map.GetAllocAddress(item.first));
}
- std::cout << "\nStatic Activation in " << map.first.name << "\n";
+ std::cout << "\nStatic Activation in " << map.first.ToString() << "\n";
const auto& static_activation_map = map.second.at(MapType::StaticActivation);
for (const auto& item : static_activation_map) {
- if (AllocPlan(item.first)->location.device.Type() != location) continue;
+ if (AllocPlan(item.first)->location.Type() != location) continue;
PrintInforPerTensor(*AllocPlan(item.first), item.second, static_activation_map.GetAllocAddress(item.first));
}
- std::cout << "\nDynamic Activation in " << map.first.name << "\n";
+ std::cout << "\nDynamic Activation in " << map.first.ToString() << "\n";
const auto& dynamic_activation_map = map.second.at(MapType::DynamicActivation);
for (const auto& item : dynamic_activation_map) {
- if (AllocPlan(item.first)->location.device.Type() != location) continue;
+ if (AllocPlan(item.first)->location.Type() != location) continue;
PrintInforPerTensor(*AllocPlan(item.first), item.second, dynamic_activation_map.GetAllocAddress(item.first));
}
}
@@ -273,10 +273,10 @@ void MemoryProfiler::CreateEvents(const std::string& p_name,
// Create Event for each tensor
auto& time_step_trace = GetMemoryInfo().time_step_trace_;
for (const auto& location_map : GetMemoryInfo().tensors_memory_info_map_) {
- const OrtMemoryInfo& memory_info = location_map.first;
+ const OrtDevice& memory_info = location_map.first;
const auto& maptype_to_map_mapping = location_map.second;
- if (memory_info.device.Type() != device_type) continue;
+ if (memory_info.Type() != device_type) continue;
// If there is no tensor of a map_type, we skip creating event for that map_type
if (maptype_to_map_mapping.find(map_type) == maptype_to_map_mapping.end()) continue;
diff --git a/onnxruntime/core/framework/memory_info.h b/onnxruntime/core/framework/memory_info.h
index 4b11dcaf7a..e38ccb3d94 100644
--- a/onnxruntime/core/framework/memory_info.h
+++ b/onnxruntime/core/framework/memory_info.h
@@ -122,7 +122,7 @@ class MemoryInfo {
bool inplace_reuse{false};
OrtValueIndex reused_buffer{0}; // The index of the reused tensor, if no reuse, it is its own tensor.
AllocKind alloc_kind{AllocKind::kAllocate};
- OrtMemoryInfo location;
+ OrtDevice location;
};
struct AllocationSummary {
@@ -185,7 +185,7 @@ class MemoryInfo {
}
// Key: The map type. E.g., initializer, activation. Value: A map from the tensor index to its memory information
- std::map > tensors_memory_info_map_;
+ std::map > tensors_memory_info_map_;
// Key: The tensor index. Value: The Allocation information per tensor
std::map tensor_alloc_info_map_;
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index c4412648d7..e9ed36294d 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -27,6 +27,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QEmbedLayerNormalization
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QGemm);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearAdd);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConcat);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearConvTranspose);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearWhere);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearLeakyRelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QLinearMul);
@@ -109,6 +110,7 @@ class OpSet_Microsoft_ver1 {
static void ForEachSchema(std::function fn) {
fn(GetOpSchema());
fn(GetOpSchema());
+ fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index d44e0ab155..fd6b3df934 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1439,7 +1439,7 @@ class MLAS_HALF_GEMM_POSTPROCESSOR {
/**
* @brief Half precision activation functions, with optional sum tensor.
* Supplied sum tensor must be the same layout as the GEMM output tensor.
- * And the supplied sum tensor will be added to the final result.
+ * And the supplied sum tensor will be added to the tensor before activation.
*/
class MLAS_HALF_GEMM_ACTIVATION_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR
{
diff --git a/onnxruntime/core/mlas/lib/activate_fp16.cpp b/onnxruntime/core/mlas/lib/activate_fp16.cpp
index f62ffe0db9..776ec67fcc 100644
--- a/onnxruntime/core/mlas/lib/activate_fp16.cpp
+++ b/onnxruntime/core/mlas/lib/activate_fp16.cpp
@@ -689,35 +689,35 @@ MlasActivationKernel(
size_t n = CountN;
while (n >= 8) {
- MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer);
MLAS_FLOAT16X8 AVec = MlasLoadFloat16x8(addsrc);
+ MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer);
addsrc += 8;
- Vector = ActivationFunction.Activate(Vector);
Vector = MlasAddFloat16x8(Vector, AVec);
+ Vector = ActivationFunction.Activate(Vector);
MlasStoreFloat16x8(buffer, Vector);
buffer += 8;
n -= 8;
}
if (n >= 4) {
- MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer);
MLAS_FLOAT16X4 AVec = MlasLoadFloat16x4(addsrc);
+ MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer);
addsrc += 4;
- Vector = ActivationFunction.Activate(Vector);
Vector = MlasAddFloat16x4(Vector, AVec);
+ Vector = ActivationFunction.Activate(Vector);
MlasStoreFloat16x4(buffer, Vector);
buffer += 4;
n -= 4;
}
if (n > 0) {
- MLAS_FLOAT16X4 buf;
- std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_));
MLAS_FLOAT16X4 addbuf;
+ MLAS_FLOAT16X4 buf;
std::memcpy(&addbuf, addsrc, n * sizeof(_mlas_fp16_));
- MLAS_FLOAT16X4 res = ActivationFunction.Activate(buf);
- res = MlasAddFloat16x4(res, addbuf);
- MlasStorePartialFloat16x4(buffer, res, n);
+ std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_));
+ buf = MlasAddFloat16x4(buf, addbuf);
+ buf = ActivationFunction.Activate(buf);
+ MlasStorePartialFloat16x4(buffer, buf, n);
}
CRow += ldc;
@@ -858,8 +858,6 @@ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process(
) const
{
std::vector buffer(CountM*CountN);
- MLAS_HALF_GEMM_2FLOAT_PROCESSOR proc(this->Activation_, buffer.data(), CountN);
- proc.Process(C, StartM, StartN, CountM, CountN, ldc);
_mlas_fp16_* Output = reinterpret_cast<_mlas_fp16_*>(C);
auto* CRow = buffer.data();
@@ -876,6 +874,8 @@ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR::Process(
}
CAdd += ldc;
}
+ MlasActivation(&this->Activation_, CRow, nullptr, 1, CountN, CountN);
+
CvtFloat2Half(Output, CRow, CountN);
CRow += CountN;
Output += ldc;
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index ea620877ed..7155a9dd27 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -782,7 +782,7 @@ extern "C" {
// value.
//
-#define MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT 32
+#define MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT 64
//
// Define the target number of per-thread multiplies before using another
diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc
index 54fe7c4087..c090ab2a6c 100644
--- a/onnxruntime/core/optimizer/conv_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc
@@ -7,6 +7,7 @@
#include "core/common/inlined_containers.h"
#include "core/framework/tensorprotoutils.h"
+#include "core/mlas/inc/mlas.h"
#include "core/graph/graph_utils.h"
#include "core/graph/node_attr_utils.h"
#include "core/optimizer/utils.h"
@@ -49,18 +50,30 @@ bool ConvFusionDataTypeCheck(const Node& conv_node) {
// Assess the support level for the other compatible EPs and if they also
// only support float, remove the EP check altogether.
const std::string_view node_ep = conv_node.GetExecutionProviderType();
- if (node_ep == kCudaExecutionProvider || node_ep == kCpuExecutionProvider) {
+ if (node_ep == kCudaExecutionProvider) {
if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
return false;
}
}
+ if (node_ep == kCpuExecutionProvider) {
+#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
+ if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT) &&
+ !HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) {
+ return false;
+ }
+#else
+ if (!HasElementDataType(*conv_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
+ return false;
+ }
+#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED
+ }
return true;
}
-class ConvActivation : public NodeSelector {
+class ConvActivationSelector : public NodeSelector {
public:
- ConvActivation() = default;
+ ConvActivationSelector() = default;
std::optional Select(const GraphViewer& graph_viewer, const Node& node) const override {
const std::string_view node_ep = node.GetExecutionProviderType();
@@ -159,7 +172,7 @@ class ConvAddRelu : public NodeSelector {
namespace actions {
using NTO = NodesToOptimize;
-class FuseConvActivation : public ReplaceWithNew {
+class FuseConvActivationAction : public ReplaceWithNew {
private:
std::string OpType(const RuntimeState&) const override { return "FusedConv"; }
@@ -245,9 +258,9 @@ class FuseConvAddRelu : public ReplaceWithNew {
void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) {
const auto name = "ConvAct";
- auto action = std::make_unique();
+ auto action = std::make_unique();
#if !defined(ORT_MINIMAL_BUILD)
- auto selector = std::make_unique();
+ auto selector = std::make_unique();
registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}},
std::move(selector), std::move(action));
#else
diff --git a/onnxruntime/core/optimizer/conv_add_act_fusion.cc b/onnxruntime/core/optimizer/conv_add_act_fusion.cc
index 5e34bce362..7c8bfeaec5 100644
--- a/onnxruntime/core/optimizer/conv_add_act_fusion.cc
+++ b/onnxruntime/core/optimizer/conv_add_act_fusion.cc
@@ -40,20 +40,28 @@ const Node* GetLoneConsumerNode(const GraphViewer& graph_viewer, const Node& nod
return &*node.OutputNodesBegin();
}
-class ConvAddActivation : public NodeSelector {
+class ConvAddActivationSelector : public NodeSelector {
public:
- ConvAddActivation() = default;
-
+ ConvAddActivationSelector() = default;
std::optional Select(const GraphViewer& graph_viewer, const Node& node) const override {
const std::string_view node_ep = node.GetExecutionProviderType();
- if (node_ep != kCpuExecutionProvider || !HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
+#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
+ if (node_ep != kCpuExecutionProvider ||
+ (!HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT) &&
+ !HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16))) {
+ return std::nullopt;
+ }
+#else
+ if (node_ep != kCpuExecutionProvider ||
+ !HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
return std::nullopt;
}
+#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED
// we can't assign `conv_node` as the producer-node, even it is, because we have to make sure
// 1. Its type is 'conv', 2. it has to satisfy the other requirements,like shape, please refer to SelectConvProducer for more info
const Node* conv_node = nullptr;
const auto* add_node = GetLoneConsumerNode(graph_viewer, node);
- if (!add_node) {
+ if (add_node == nullptr) {
return std::nullopt;
}
// Let's support addition first, leave any-element-wise-op fusion in the future.
@@ -64,13 +72,13 @@ class ConvAddActivation : public NodeSelector {
if (graph_utils::IsSupportedOptypeVersionAndDomain(*add_node, "Add", {7, 13, 14})) {
conv_node = SelectProducerConv(*add_node);
}
- if (!conv_node) {
+ if (conv_node == nullptr) {
return std::nullopt;
}
// GetLoneConsumerNode will ensure outputedge_count is 1
const auto* act_node = GetLoneConsumerNode(graph_viewer, *add_node);
// even the next node is not a activation node, it's also fine.
- if (!act_node) {
+ if (act_node == nullptr) {
// we can't fuse add-activation when add_node has multiple consumer nodes
act_node = nullptr;
} else if (SelectActivation(graph_viewer, *act_node)) {
@@ -82,7 +90,7 @@ class ConvAddActivation : public NodeSelector {
NodesToOptimizeIndicesBuilder builder{};
builder.target_node = conv_node->Index();
builder.output_nodes = {add_node->Index()};
- if (act_node) {
+ if (act_node != nullptr) {
builder.output_nodes.push_back(act_node->Index());
}
return builder.Build();
@@ -167,17 +175,27 @@ class ConvAddActivation : public NodeSelector {
// Check if this is a single use convolution that hasn't already
// been fused with another Add/Sum node. The Add/Sum can also only be
// fused if the convolution isn't itself fused with an activation.
- if ((inputs_node[n]->OpType() == "Conv") && (pre_input_defs_count < 4) && (producer_input_args_count.size() < 4) &&
- (graph_utils::GetNodeAttribute(*inputs_node[n], "activation") == nullptr) && (inputs_node[n]->GetOutputEdgesCount() == 1)) {
- if (pre_input_defs_count < 3) {
- // The optional bias parameter is empty so set to an empty string.
+ if ((inputs_node[n]->OpType() == "Conv") && (pre_input_defs_count < 4) &&
+ (producer_input_args_count.size() < 4) &&
+ (graph_utils::GetNodeAttribute(*inputs_node[n], "activation") == nullptr) &&
+ (inputs_node[n]->GetOutputEdgesCount() == 1)) {
+ if (pre_input_defs_count < 3) { // The optional bias parameter is empty so set to an empty string.
+ // TODO, add a new null arguments for bias
+ continue;
+ }
+ return inputs_node[n];
+ }
+ if (inputs_node[n]->OpType() == "NhwcFusedConv" && (pre_input_defs_count < 4) &&
+ (producer_input_args_count.size() < 5) &&
+ (graph_utils::GetNodeAttribute(*inputs_node[n], "activation") == nullptr) &&
+ (inputs_node[n]->GetOutputEdgesCount() == 1)) {
+ if (pre_input_defs_count < 3) { // The optional bias parameter is empty so set to an empty string.
// TODO, add a new null arguments for bias
continue;
}
return inputs_node[n];
}
}
-
return nullptr;
}
};
@@ -187,9 +205,14 @@ class ConvAddActivation : public NodeSelector {
namespace actions {
using NTO = NodesToOptimize;
-class FuseConvAddActivation : public ReplaceWithNew {
+class FuseConvAddActivationAction : public ReplaceWithNew {
+ public:
+ FuseConvAddActivationAction() = default;
+
private:
- std::string OpType(const RuntimeState&) const override { return "FusedConv"; }
+ std::string OpType(const RuntimeState& runtimeState) const override {
+ return (runtimeState.selected_nodes.Target().OpType() == "Conv") ? "FusedConv" : "NhwcFusedConv";
+ }
std::string Domain(const RuntimeState&) const override { return kMSDomain; }
@@ -262,11 +285,14 @@ class FuseConvAddActivation : public ReplaceWithNew {
} // namespace actions
void RegisterConvAddActivationFusionRules(SelectorActionRegistry& registry) {
- const auto name = "ConvAddAct";
- auto action = std::make_unique();
- auto selector = std::make_unique();
- registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}},
+ auto action = std::make_unique();
+ auto selector = std::make_unique();
+ registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}},
std::move(selector), std::move(action));
+ auto action_nhwc = std::make_unique();
+ auto selector_nhwc = std::make_unique();
+ registry.RegisterSelectorAndAction("NhwcFusedConvAct", {{"NhwcFusedConv", {1, 11}}},
+ std::move(selector_nhwc), std::move(action_nhwc));
}
SelectorActionRegistry CreateSelectorActionRegistry() {
diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc
index 540e0e92d3..e182b6c695 100644
--- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc
+++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc
@@ -93,9 +93,10 @@ static Status MatchAndProcess(
Status status = Status::OK();
do {
- // TODO: for now this just needs to support ONNX ops. If we ever had a transformer that was going to
- // target non-ONNX ops we'd need to rework a few things to include the op domain in the matches
- if (node.Domain() != kOnnxDomain) {
+ // TODO: for now this just needs to support ONNX and Microsoft Domain ops.
+ // If we ever had a transformer that was going to target non-ONNX ops,
+ // we'd need to rework a few things to include the op domain in the matches
+ if (node.Domain() != kOnnxDomain && node.Domain() != kMSDomain) {
break;
}
diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc
index b9617accd6..022ecc30f3 100644
--- a/onnxruntime/core/platform/windows/env.cc
+++ b/onnxruntime/core/platform/windows/env.cc
@@ -137,12 +137,20 @@ class WindowsThread : public EnvThread {
static unsigned __stdcall ThreadMain(void* param) {
std::unique_ptr p(static_cast(param));
- const ORTCHAR_T* name_prefix =
- (p->name_prefix == nullptr || wcslen(p->name_prefix) == 0) ? L"onnxruntime" : p->name_prefix;
- std::wostringstream oss;
- oss << name_prefix << "-" << p->index;
- // Ignore the error
- (void)SetThreadDescription(GetCurrentThread(), oss.str().c_str());
+ // Not all machines have kernel32.dll and/or SetThreadDescription (e.g. Azure App Service sandbox)
+ // so we need to ensure it's available before calling.
+ HMODULE kernelModule = GetModuleHandle(TEXT("kernel32.dll"));
+ if (kernelModule != nullptr) {
+ auto setThreadDescriptionFn = (SetThreadDescriptionFunc)GetProcAddress(kernelModule, "SetThreadDescription");
+ if (setThreadDescriptionFn != nullptr) {
+ const ORTCHAR_T* name_prefix = (p->name_prefix == nullptr || wcslen(p->name_prefix) == 0) ? L"onnxruntime"
+ : p->name_prefix;
+ std::wostringstream oss;
+ oss << name_prefix << "-" << p->index;
+ // Ignore any errors
+ (void)(setThreadDescriptionFn)(GetCurrentThread(), oss.str().c_str());
+ }
+ }
unsigned ret = 0;
ORT_TRY {
diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc
index d06946579b..6db63752c5 100644
--- a/onnxruntime/core/providers/cann/cann_execution_provider.cc
+++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc
@@ -1511,7 +1511,7 @@ void CANNExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry&
OrtDevice CANNExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
- return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CANN_PINNED, default_device_.Id());
+ return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CANN_PINNED, 0 /*CPU device id always be 0*/);
}
return default_device_;
}
diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc
index f2055aa8e8..d062d59de8 100644
--- a/onnxruntime/core/providers/coreml/builders/helper.cc
+++ b/onnxruntime/core/providers/coreml/builders/helper.cc
@@ -46,6 +46,11 @@ bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const lo
}
bool IsInputSupported(const NodeArg& input, const std::string& parent_name, const logging::Logger& logger) {
+ if (!input.Exists()) {
+ // optional input that is not provided
+ return true;
+ }
+
const auto& input_name = input.Name();
const auto* shape_proto = input.Shape();
// We do not support input with no shape
@@ -86,12 +91,6 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe
}
#endif
- const auto& graph_inputs = graph_viewer.GetInputs();
- if (std::any_of(graph_inputs.begin(), graph_inputs.end(),
- [&](const NodeArg* input) { return !IsInputSupported(*input, "graph", logger); })) {
- return supported_nodes;
- }
-
for (const auto& node : graph_viewer.Nodes()) {
const bool supported = IsNodeSupported(node, graph_viewer, logger);
LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType()
diff --git a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
index d4bf25ab64..e6867f1081 100644
--- a/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
+++ b/onnxruntime/core/providers/cpu/fp16/fp16_conv.cc
@@ -32,7 +32,7 @@ using ConvPadVector = ConvAttributes::ConvPadVector;
* 2. Activation
* It takes an operator attribute 'activation', which supplies the activation info.
*
- * Add is performed AFTER activation.
+ * Add is performed BEFORE activation.
*
* The implementation supports both NCHW and NHWC. It runs faster with NHWC.
*
@@ -281,12 +281,10 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
if (Y->Shape().Size() == 0) {
return Status::OK();
}
- if (Sum) {
- if (Sum->Shape() != Y->Shape()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Z shape does not match output shape.",
- " Z: ", Sum->Shape().ToString().c_str(),
- " Output: ", Y->Shape().ToString().c_str());
- }
+ if (Sum && Sum->Shape() != Y->Shape()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Z shape does not match output shape.",
+ " Z: ", Sum->Shape().ToString().c_str(),
+ " Output: ", Y->Shape().ToString().c_str());
}
const int64_t input_image_size = input_shape.Size();
@@ -338,7 +336,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
const auto* Xdata = X->Data();
const auto* Bdata = B != nullptr ? B->Data() : nullptr;
auto* Ydata = Y->MutableData();
- const auto* SumData = Sum != nullptr ? Sum->Data() : nullptr;
+ const auto* sum_data = Sum != nullptr ? Sum->Data() : nullptr;
BufferUniquePtr transpose_input_buffer;
BufferUniquePtr transpose_output_buffer;
@@ -409,7 +407,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
for (int64_t image_id = 0; image_id < N; ++image_id) {
const auto* input_data = Xdata;
auto* output_data = Ydata;
- const auto* add_src = SumData;
+ const auto* add_src = sum_data;
if (!channels_last_) {
// Transpose the input from channels first (CHW) to channels last (HWC).
@@ -478,7 +476,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
static_cast(M),
static_cast(output_count),
static_cast(kernel_size),
- &act);
+ (!channels_last_ && sum_data) ? nullptr : &act);
} else {
for (int64_t group_id = 0; group_id < group_count; ++group_id) {
// Prepare the im2col transformation or use the input buffer directly for
@@ -554,7 +552,7 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
gemm_params.C = worker_output + group_id * group_output_channels;
gemm_params.ldc = static_cast(M);
gemm_params.Bias = Bdata;
- gemm_params.OutputProcessor = &act; // process fused activation and add
+ gemm_params.OutputProcessor = (!channels_last_ && sum_data) ? nullptr : &act; // process fused activation and add
MlasHalfGemmBatch(
static_cast(output_count),
@@ -574,10 +572,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
Ydata,
static_cast(output_image_size),
static_cast(M));
- if (SumData != nullptr) {
- MLAS_ACTIVATION activation;
- activation.ActivationKind = MlasIdentityActivation;
- MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(activation, SumData);
+ if (sum_data != nullptr) {
+ MLAS_HALF_GEMM_ACTIVATION_PROCESSOR proc(activation_, sum_data);
proc.Process(Ydata, 0, 0, static_cast(M),
static_cast(output_image_size),
static_cast(output_image_size));
@@ -586,8 +582,8 @@ Status FusedConvFp16::Compute(OpKernelContext* context) const {
Xdata += X_offset;
Ydata += Y_offset;
- if (SumData != nullptr) {
- SumData += Y_offset;
+ if (sum_data != nullptr) {
+ sum_data += Y_offset;
}
}
diff --git a/onnxruntime/core/providers/cpu/math/matmul_helper.h b/onnxruntime/core/providers/cpu/math/matmul_helper.h
index 7c86f11b97..d7275ee324 100644
--- a/onnxruntime/core/providers/cpu/math/matmul_helper.h
+++ b/onnxruntime/core/providers/cpu/math/matmul_helper.h
@@ -371,9 +371,24 @@ class MatMulComputeHelper {
return right_zp_offsets_;
}
+ static bool IsAligned(const std::vector& offsets) {
+ constexpr size_t alignment = 16;
+ const auto len = offsets.size();
+ for (size_t i = 0; i < len; i++) {
+ if ((offsets[i] % alignment) != 0) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool IsBatchedGemmAligned() const {
+ return IsAligned(left_offsets_) && IsAligned(right_offsets_) && IsAligned(output_offsets_);
+ }
+
template
static void OffsetToArrays(T* p, const std::vector& offsets, gsl::span arrays) {
- auto len = offsets.size();
+ const auto len = offsets.size();
ORT_ENFORCE(arrays.size() == len);
for (size_t i = 0; i < len; i++) {
arrays[i] = p + offsets[i];
@@ -382,7 +397,7 @@ class MatMulComputeHelper {
template
static void OffsetToArrays(const T* p, const std::vector& offsets, gsl::span arrays) {
- auto len = offsets.size();
+ const auto len = offsets.size();
ORT_ENFORCE(arrays.size() == len);
for (size_t i = 0; i < len; i++) {
arrays[i] = p + offsets[i];
diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h
index a4d67ec63f..984f4795ce 100644
--- a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h
+++ b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h
@@ -44,14 +44,16 @@ struct ConvTransposeAttributes : public ConvAttributes {
};
Status PrepareForCompute(OpKernelContext* context, bool has_bias, Prepare& p,
- bool dynamic_padding = false, const TensorShape* filter_shape = nullptr) const {
+ bool dynamic_padding = false, const TensorShape* filter_shape = nullptr, bool is_quant = false) const {
const Tensor* X = context->Input(0);
- const Tensor* F = (filter_shape != nullptr) ? nullptr : context->Input(1);
+ const Tensor* F = (filter_shape != nullptr) ? nullptr : context->Input(is_quant ? 3 : 1);
const TensorShape& F_Shape = (filter_shape != nullptr) ? *filter_shape : F->Shape();
+ if (dynamic_padding && is_quant) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "dynamic padding is not supported for quantized conv tranpose.");
+ }
const Tensor* Pads = dynamic_padding ? context->Input(2) : nullptr;
- const Tensor* B = has_bias ? (dynamic_padding ? context->Input(3) : context->Input(2)) : nullptr;
+ const Tensor* B = has_bias ? (dynamic_padding ? context->Input(is_quant ? 9 : 3) : context->Input(is_quant ? 8 : 2)) : nullptr;
TensorShape input_shape = X->Shape().Slice(2);
-
const int64_t num_input_channels = X->Shape()[1];
const int64_t N = X->Shape()[0];
const int64_t num_output_channels_multiplier = F_Shape[1];
diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconvtranspose.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconvtranspose.cc
new file mode 100644
index 0000000000..513f72030b
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/quantization/qlinearconvtranspose.cc
@@ -0,0 +1,254 @@
+/**
+* Copyright (c) 2014-present, Quadric, Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+/* Modifications Copyright (c) Microsoft. */
+
+#include "core/mlas/inc/mlas.h"
+#include "core/common/safeint.h"
+#include "core/util/math.h"
+#include "core/util/math_cpuonly.h"
+
+#include "core/common/inlined_containers_fwd.h"
+#include "core/framework/transpose_helper.h"
+#include "core/providers/utils.h"
+#include "core/framework/tensorprotoutils.h"
+
+#include "core/providers/cpu/nn/conv_transpose_attributes.h"
+
+namespace onnxruntime {
+
+template
+class QLinearConvTranspose : public OpKernel {
+ public:
+ explicit QLinearConvTranspose(const OpKernelInfo& info) : OpKernel(info), conv_transpose_attrs_(info) {
+ }
+
+ Status Compute(OpKernelContext* context) const override;
+
+ Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ /*out*/ bool& is_packed,
+ /*out*/ PrePackedWeights* prepacked_weights) override;
+
+ Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers,
+ int input_idx,
+ /*out*/ bool& used_shared_buffers) override;
+
+ protected:
+ Status DoConvTranspose(OpKernelContext* context) const;
+
+ private:
+ static float ComputeOutputScale(OpKernelContext* context) {
+ const Tensor* X_scale = context->Input(InputTensors::IN_X_SCALE);
+ const Tensor* W_scale = context->Input(InputTensors::IN_W_SCALE);
+ const Tensor* Y_scale = context->Input(InputTensors::IN_Y_SCALE);
+ ORT_ENFORCE(IsScalarOr1ElementVector(X_scale),
+ "QLinearConv : input scale must be a scalar or 1D tensor of size 1");
+ ORT_ENFORCE(IsScalarOr1ElementVector(Y_scale),
+ "QLinearConv : result scale must be a scalar or 1D tensor of size 1");
+ ORT_ENFORCE(IsScalarOr1ElementVector(W_scale),
+ "QLinearConv : filter scale must be a scalar or 1D tensor of size 1");
+
+ auto X_scale_value = *(X_scale->Data());
+ auto Y_scale_value = *(Y_scale->Data());
+ auto W_scale_value = *(W_scale->Data());
+
+ return X_scale_value * W_scale_value / Y_scale_value;
+ }
+
+ enum InputTensors : int {
+ IN_X = 0,
+ IN_X_SCALE = 1,
+ IN_X_ZERO_POINT = 2,
+ IN_W = 3,
+ IN_W_SCALE = 4,
+ IN_W_ZERO_POINT = 5,
+ IN_Y_SCALE = 6,
+ IN_Y_ZERO_POINT = 7,
+ IN_BIAS = 8
+
+ };
+ enum OutputTensors : int {
+ OUT_Y = 0
+ };
+
+ ConvTransposeAttributes conv_transpose_attrs_;
+
+ // for pre-packing usage
+ TensorShape filter_shape_;
+ BufferUniquePtr transposed_filter_;
+};
+
+template
+Status QLinearConvTranspose::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
+ /*out*/ bool& is_packed,
+ /*out*/ PrePackedWeights* /*prepacked_weights*/
+) {
+ is_packed = false;
+ return Status::OK();
+}
+
+template
+Status QLinearConvTranspose::UseSharedPrePackedBuffers(std::vector& /*prepacked_buffers*/,
+ int /*input_idx*/,
+ /*out*/ bool& used_shared_buffers) {
+ used_shared_buffers = false;
+ return Status::OK();
+}
+
+template
+Status QLinearConvTranspose::Compute(OpKernelContext* context) const {
+ return QLinearConvTranspose::DoConvTranspose(context);
+}
+
+template
+Status QLinearConvTranspose::DoConvTranspose(OpKernelContext* context) const {
+ typedef int32_t ActType;
+ size_t num_inputs = OpKernel::Node().InputDefs().size();
+ ConvTransposeAttributes::Prepare p;
+ bool has_bias = num_inputs == 9;
+ ORT_RETURN_IF_ERROR(conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, false, nullptr, true));
+
+ const int64_t input_image_size = p.input_shape.Size();
+
+ // Bail out early if one of the dimensions is zero.
+ if (p.Y->Shape().Size() == 0) {
+ return Status::OK();
+ }
+
+ // Quantization parameters, only support symmetric quant for now
+ float scale_value = ComputeOutputScale(context);
+
+ const int64_t X_offset = p.num_input_channels / conv_transpose_attrs_.group * input_image_size;
+ const int64_t Y_offset = p.Y->Shape().Size() / p.Y->Shape()[0] / conv_transpose_attrs_.group;
+ const int64_t W_offset = p.F->Shape().Size() / conv_transpose_attrs_.group;
+ const int64_t kernel_size = TensorShape(p.kernel_shape).Size();
+ const int64_t kernel_dim = p.num_output_channels / conv_transpose_attrs_.group * kernel_size;
+ const int64_t output_size = (p.Y->Shape().Slice(2)).Size();
+
+ AllocatorPtr alloc;
+ ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
+
+ const int64_t col_buffer_size = kernel_dim * p.input_shape.Size();
+ auto col_data = alloc->Alloc(SafeInt(sizeof(ActType)) * col_buffer_size);
+ BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
+
+ // Pre-transpose the weight matrix because MatMul does not take transpose params
+ const int64_t trans_filt_size = p.num_input_channels / conv_transpose_attrs_.group * kernel_dim;
+ auto trans_filt_data = alloc->Alloc(SafeInt(sizeof(T)) * trans_filt_size);
+ BufferUniquePtr trans_filt(trans_filt_data, BufferDeleter(alloc));
+ ActType* col_buffer_data = static_cast(col_buffer.get());
+
+ const T* Xdata = p.X->Data();
+ const T* filter_data = p.F->Data();
+ T* Ydata = p.Y->MutableData();
+ TensorShape output_shape = p.Y->Shape().Slice(2);
+ MlasTranspose(
+ filter_data,
+ static_cast(trans_filt.get()),
+ p.num_input_channels / conv_transpose_attrs_.group,
+ kernel_dim);
+
+ //TODO: add support for asymmetric quantization and using int8 MlasGemm. This will require offsetting the
+ //input data to a uint8 range and passing a zero point parameter.
+ //For now compute the GEMM in int32, as MlasGemm requires uint8 input and MlasSymmQgemmBatch is ARM only
+ auto inp_i32_buffer = alloc->Alloc(SafeInt(sizeof(ActType)) * p.X->Shape().Size());
+ BufferUniquePtr inp_i32(inp_i32_buffer, BufferDeleter(alloc));
+ ActType* inp_i32_data = static_cast(inp_i32.get());
+ for (std::int64_t i = 0; i < p.X->Shape().Size(); i++) {
+ inp_i32_data[i] = Xdata[i];
+ }
+ auto wt_i32_buffer = alloc->Alloc(SafeInt(sizeof(ActType)) * p.F->Shape().Size());
+ BufferUniquePtr wt_i32(wt_i32_buffer, BufferDeleter(alloc));
+ ActType* wt_i32_data = static_cast(wt_i32.get());
+ for (std::int64_t i = 0; i < p.F->Shape().Size(); i++) {
+ wt_i32_data[i] = static_cast(trans_filt.get())[i];
+ }
+ auto out_i32_buffer = alloc->Alloc(SafeInt(sizeof(ActType)) * p.Y->Shape().Size());
+ BufferUniquePtr out_i32(out_i32_buffer, BufferDeleter(std::move(alloc)));
+ ActType* out_i32_data = static_cast(out_i32.get());
+
+ for (auto image_id = 0; image_id < p.N; ++image_id) {
+ for (int group_id = 0; group_id < conv_transpose_attrs_.group; ++group_id) {
+ math::MatMul(kernel_dim, input_image_size, p.num_input_channels / conv_transpose_attrs_.group,
+ wt_i32_data + group_id * W_offset,
+ inp_i32_data + group_id * X_offset,
+ col_buffer_data,
+ nullptr);
+
+ math::Col2im(
+ reinterpret_cast(col_buffer_data),
+ p.num_output_channels / conv_transpose_attrs_.group,
+ p.Y->Shape()[2],
+ p.Y->Shape()[3],
+ p.kernel_shape[0],
+ p.kernel_shape[1],
+ p.dilations[0],
+ p.dilations[1],
+ p.pads[0],
+ p.pads[1],
+ p.pads[2],
+ p.pads[3],
+ p.strides[0],
+ p.strides[1],
+ reinterpret_cast(out_i32_data + group_id * Y_offset),
+ &CPUMathUtil::Instance());
+ }
+
+ if (p.B != nullptr) {
+ auto out_i32_matrix = EigenMatrixMap(out_i32_data, output_size, p.num_output_channels);
+ auto Bvec = ConstEigenVectorMap(p.B->Data(), p.num_output_channels);
+ out_i32_matrix.rowwise() += Bvec.transpose();
+ }
+
+ MlasRequantizeOutput(out_i32_data,
+ p.Y->Shape()[2] * p.Y->Shape()[3],
+ Ydata,
+ p.Y->Shape()[2] * p.Y->Shape()[3],
+ nullptr,
+ &scale_value,
+ false,
+ (T)0,
+ 0,
+ 0,
+ p.num_output_channels,
+ p.Y->Shape()[2] * p.Y->Shape()[3]);
+ Xdata += X_offset * conv_transpose_attrs_.group;
+ Ydata += Y_offset * conv_transpose_attrs_.group;
+ }
+
+ return Status::OK();
+}
+
+#ifndef DISABLE_CONTRIB_OPS
+namespace contrib {
+
+// Register the operator with int8 inputs and weights
+ONNX_OPERATOR_KERNEL_EX(
+ QLinearConvTranspose,
+ kMSDomain,
+ 1,
+ //int8_t,
+ kCpuExecutionProvider,
+ KernelDefBuilder()
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType())
+ .TypeConstraint("T3", DataTypeImpl::GetTensorType())
+ .TypeConstraint("T4", DataTypeImpl::GetTensorType()),
+ QLinearConvTranspose);
+
+} // namespace contrib
+#endif
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 0d3c54d676..2cd0467c77 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -2532,9 +2532,8 @@ void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry&
}
OrtDevice CUDAExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
- if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) {
- return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, default_device_.Id());
- }
+ if (mem_type == OrtMemTypeCPUInput) return OrtDevice();
+ if (mem_type == OrtMemTypeCPUOutput) return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device id always be 0*/);
return default_device_;
}
diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc
index 9d7e673109..5a813a98a6 100644
--- a/onnxruntime/core/providers/cuda/math/matmul.cc
+++ b/onnxruntime/core/providers/cuda/math/matmul.cc
@@ -180,6 +180,14 @@ Status MatMul::ComputeInternal(OpKernelContext* ctx) const {
ORT_RETURN_IF_ERROR(right_arrays.CopyToGpu(ctx->GetComputeStream()));
ORT_RETURN_IF_ERROR(output_arrays.CopyToGpu(ctx->GetComputeStream()));
+ // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy.
+ // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte).
+ // Assume that start memory of input/output tensor is aligned, we only check offsets of sub-matrix per batch here.
+ cublasMath_t mode = (std::is_same::value && device_prop.major >= 8 && helper.IsBatchedGemmAligned())
+ ? CUBLAS_TF32_TENSOR_OP_MATH
+ : CUBLAS_DEFAULT_MATH;
+ CublasMathModeSetter math_mode_setter(device_prop, GetCublasHandle(ctx), mode);
+
// note that onnxruntime OrtValue is row major, while cublas is column major,
// so swap left/right operands
CUBLAS_RETURN_IF_ERROR(cublasGemmBatchedHelper(
diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h
index 8bf429949e..8424d14d39 100644
--- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h
+++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h
@@ -185,13 +185,7 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle,
const float* beta,
float* Carray[], int ldc,
int batch_count,
- const cudaDeviceProp& prop) {
-#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
- onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, CUBLAS_TF32_TENSOR_OP_MATH);
-#else
- ORT_UNUSED_PARAMETER(prop);
-#endif
-
+ const cudaDeviceProp&) {
return cublasSgemmBatched(handle,
transa,
transb,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp
index f5abeb5d7b..aae73dca46 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp
@@ -244,17 +244,12 @@ namespace Dml
void* CPUAllocator::Alloc(size_t size)
{
- if (size <= 0)
- {
- return nullptr;
- }
- void* p = malloc(size);
- return p;
+ return onnxruntime::AllocatorDefaultAlloc(size);
}
void CPUAllocator::Free(void* p)
{
- free(p);
+ return onnxruntime::AllocatorDefaultFree(p);
}
} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
index 55322bea6b..2f1890bb0b 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
@@ -210,6 +210,8 @@ namespace Dml
m_cpuOutputAllocator = std::make_shared(OrtMemType::OrtMemTypeCPUOutput);
CreateDmlKernelRegistry(&m_kernelRegistry, &m_internalRegInfoMap);
+
+ m_lastUploadFlushTime = std::chrono::steady_clock::now();
}
HRESULT __stdcall ExecutionProviderImpl::GetD3DDevice(_COM_Outptr_ ID3D12Device** d3dDevice) const noexcept
@@ -447,6 +449,7 @@ namespace Dml
const auto dstState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state
m_uploadHeap->BeginUploadToGpu(dstData, dstOffset, dstState, AsByteSpan(srcData, dataSizeInBytes));
+ FlushUploadsIfReady();
}
else if (!src->IsCpuData() && dst->IsCpuData())
{
@@ -566,12 +569,23 @@ namespace Dml
assert(!m_closed);
m_uploadHeap->BeginUploadToGpu(dstData, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, AsByteSpan(srcData, static_cast(srcDataSize)));
+ FlushUploadsIfReady();
return S_OK;
}
ORT_CATCH_RETURN
}
+ void ExecutionProviderImpl::FlushUploadsIfReady() const
+ {
+ // Periodically flush uploads to make sure the GPU is not idle for too long
+ if (std::chrono::steady_clock::now() - m_lastUploadFlushTime > m_batchFlushInterval)
+ {
+ Flush();
+ m_lastUploadFlushTime = std::chrono::steady_clock::now();
+ }
+ }
+
uint32_t ExecutionProviderImpl::GetSupportedDeviceDataTypeMask() const
{
// The DML provider registers all supported kernels up-front regardless of actual device capability,
@@ -661,9 +675,10 @@ namespace Dml
bool IsCustomOpShader(const onnxruntime::Node& node)
{
- auto custom_ops = std::array{
+ auto custom_ops = std::array{
"DFT",
- "STFT"
+ "STFT",
+ "GridSample"
};
for (auto& custom_op : custom_ops)
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
index b9ac772095..04f3420f96 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
@@ -180,6 +180,8 @@ namespace Dml
uint32_t supportedDeviceDataTypeMask // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
) const;
+ void FlushUploadsIfReady() const;
+
ComPtr m_d3d12Device;
ComPtr m_dmlDevice;
bool m_isMcdmDevice = false;
@@ -195,6 +197,8 @@ namespace Dml
std::shared_ptr m_internalRegInfoMap;
mutable uint64_t m_partitionKernelPrefixVal = 0;
bool m_closed = false;
+ mutable std::chrono::time_point m_lastUploadFlushTime;
+ static constexpr std::chrono::milliseconds m_batchFlushInterval = std::chrono::milliseconds(10);
};
class DataTransfer : public onnxruntime::IDataTransfer
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h
index 76ca37bd05..9a1c23093f 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiHelpers.h
@@ -183,13 +183,13 @@ class StackAllocator
static T RoundUpToMultiple(T value, T multiple)
{
static_assert(std::is_integral_v);
-
+
T remainder = value % multiple;
if (remainder != 0)
{
value += multiple - remainder;
}
-
+
return value;
}
@@ -231,4 +231,4 @@ class StackAllocator
// allocated memory if the fixed stack array is exhausted.
FixedBucket m_fixed;
std::deque m_dynamic;
-};
\ No newline at end of file
+};
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
index 8b5cf36937..c75b662af7 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
@@ -1155,6 +1155,11 @@ struct OperatorDescTraits
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_GELU;
};
+template <>
+struct OperatorDescTraits<DML_MULTIHEAD_ATTENTION_OPERATOR_DESC>
+{
+ static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION;
+};
template <DML_OPERATOR_TYPE Type>
struct OperatorTypeTraits
@@ -2139,14 +2144,20 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_GELU>
using DescType = DML_ACTIVATION_GELU_OPERATOR_DESC;
};
+template <>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION>
+{
+ using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC;
+};
+
// Calls a visitor functor, supplying an empty operator desc corresponding to the given DML_OPERATOR_TYPE as
// the first argument.
-//
+//
// For example:
// Visit(DML_OPERATOR_ELEMENT_WISE_IDENTITY, [](auto tag) {
// using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs
// });
-//
+//
#pragma warning(push)
#pragma warning(disable:4702)
template <typename Visitor, typename... Ts>
@@ -2432,6 +2443,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward(visitor), DML_RESAMPLE_GRAD1_OPERATOR_DESC{}, std::forward(args)...);
case DML_OPERATOR_DIAGONAL_MATRIX1:
return std::invoke(std::forward(visitor), DML_DIAGONAL_MATRIX1_OPERATOR_DESC{}, std::forward(args)...);
+ case DML_OPERATOR_MULTIHEAD_ATTENTION:
+ return std::invoke(std::forward(visitor), DML_MULTIHEAD_ATTENTION_OPERATOR_DESC{}, std::forward(args)...);
case DML_OPERATOR_ACTIVATION_ELU:
return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...);
case DML_OPERATOR_ACTIVATION_CELU:
@@ -2633,6 +2646,7 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2";
case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1";
case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1";
+ case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION";
default:
assert(false);
return "";
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
index 42de619a87..1ebd52d4ed 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
@@ -2302,6 +2302,35 @@ constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA{
DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS,
};
+constexpr DML_SCHEMA_FIELD DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA_FIELDS[18] {
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "QueryTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "KeyTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ValueTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "StackedQueryKeyTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "StackedKeyValueTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "StackedQueryKeyValueTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MaskTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "RelativePositionBiasTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastKeyTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "PastValueTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputPresentKeyTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputPresentValueTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Scale", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "MaskFilterValue", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "HeadCount", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MaskType", false },
+};
+
+constexpr DML_OPERATOR_SCHEMA DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA {
+ "DML_OPERATOR_MULTIHEAD_ATTENTION",
+ DML_OPERATOR_MULTIHEAD_ATTENTION,
+ DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
+ 18,
+ DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA_FIELDS,
+};
+
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLX.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLX.h
index f7feb5ff21..7e050eef50 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLX.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLX.h
@@ -202,7 +202,7 @@ namespace dml
};
}
-#if DMLX_USE_ABSEIL
+#if DMLX_USE_ABSEIL
    template <typename T>
    using Optional = absl::optional<T>;
@@ -231,7 +231,7 @@ namespace dml
#elif DMLX_USE_GSL
    template <typename T>
    using Span = gsl::span<T>;
- #else
+ #else
    template <typename T>
    using Span = dml::detail::span<T>;
#endif
@@ -245,11 +245,11 @@ namespace dml
#define DMLX_THROW(_hr) THROW_HR(_hr)
#else
#define DMLX_THROW_IF_FAILED(_hr) if (FAILED(_hr)) { throw std::runtime_error(#_hr); }
- #define DMLX_THROW(_hr) throw std::runtime_error(#_hr);
+ #define DMLX_THROW(_hr) throw std::runtime_error(#_hr);
#endif
#else
#define DMLX_THROW_IF_FAILED(_hr) if (FAILED(_hr)) { std::abort(); }
- #define DMLX_THROW(_hr) { std::abort(); }
+ #define DMLX_THROW(_hr) { std::abort(); }
#endif
class Graph;
@@ -307,7 +307,7 @@ namespace dml
// (0, 2, ..., n, 1). This is often referred to as "NHWC" or "interleaved channel" layout. This is useful,
// for example, when applied to 2D Convolution to produce outputs in an NHWC layout (as opposed to NCHW, which
// is the DirectML default for 2D Convolution).
- //
+ //
// Examples of the transposes produced by this policy:
// NCW -> NWC
// NCHW -> NHWC
@@ -713,7 +713,7 @@ namespace dml
// Represents an activation to be fused with an existing operator. The meaning of param1 and param2 depend on the
// activation to be fused.
- //
+ //
// For HARD_SIGMOID, LINEAR, PARAMETRIC_SOFTPLUS, and SCALED_TANH: param1 = Alpha and param2 = Beta
// For ELU, LEAKY_RELU, THRESHOLDED_RELU, and CELU: param1 = Alpha. param2 is unused.
// For SCALED_ELU, param1 = Alpha and param2 = Gamma.
@@ -1858,13 +1858,13 @@ namespace dml
}
// Helper for setting parameters for the Convolution operator. Sample usage:
- //
+ //
// auto conv = dml::ConvolutionBuilder(...)
// .StartPadding(...)
// .EndPadding(...)
// .Strides(...)
// .Build();
- //
+ //
// Parameters left unspecified will be defaulted with the same values as dml::Convolution().
class ConvolutionBuilder
{
@@ -2114,9 +2114,9 @@ namespace dml
return output;
}
- //
+ //
// TODO: LpPooling
- //
+ //
// ---------------------------------------------------------------------------------------------------------------
@@ -2203,13 +2203,13 @@ namespace dml
}
// Helper for setting parameters for the MaxPooling operator. Sample usage:
- //
+ //
// auto [out, outIndices] = dml::MaxPoolingBuilder(...)
// .StartPadding(...)
// .EndPadding(...)
// .OutputIndices(...)
// .Build();
- //
+ //
// Parameters left unspecified will be defaulted with the same values as dml::MaxPooling().
class MaxPoolingBuilder
{
@@ -2251,13 +2251,13 @@ namespace dml
// ---------------------------------------------------------------------------------------------------------------
- //
+ //
// TODO: MaxUnpooling
- //
+ //
- //
+ //
// TODO: ROIPooling
- //
+ //
inline Expression Slice(
Expression input,
@@ -2683,7 +2683,7 @@ namespace dml
{
detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
TensorDesc inputTensor = input.Impl()->GetOutputDesc();
-
+
assert(inputTensor.sizes.size() == 4);
dml::TensorDesc::Dimensions outputSizes = {
@@ -2691,7 +2691,7 @@ namespace dml
inputTensor.sizes[1] * blockSize * blockSize,
inputTensor.sizes[2] / blockSize,
inputTensor.sizes[3] / blockSize
- };
+ };
TensorDesc outputTensor(inputTensor.dataType, outputSizes, builder->GetTensorPolicy());
@@ -2715,7 +2715,7 @@ namespace dml
{
detail::GraphBuilder* builder = input.Impl()->GetGraphBuilder();
TensorDesc inputTensor = input.Impl()->GetOutputDesc();
-
+
assert(inputTensor.sizes.size() == 4);
dml::TensorDesc::Dimensions outputSizes = {
@@ -2771,7 +2771,7 @@ namespace dml
struct TopKOutputs
{
Expression value;
- Expression index;
+ Expression index;
};
inline TopKOutputs TopK(Expression input, uint32_t axis, uint32_t k, DML_AXIS_DIRECTION axisDirection)
@@ -2909,14 +2909,14 @@ namespace dml
desc.VarianceTensor = varianceTensor.AsPtr();
desc.ScaleTensor = scaleTensor.AsPtr();
desc.Epsilon = epsilon;
-
+
desc.OutputGradientTensor = outputGradientTensor.AsPtr();
desc.OutputScaleGradientTensor = outputScaleGradientTensor.AsPtr();
desc.OutputBiasGradientTensor = outputBiasGradientTensor.AsPtr();
-
+
dml::detail::NodeOutput* const inputs[] = { input.Impl(), inputGradient.Impl(), mean.Impl(), variance.Impl(), scale.Impl() };
dml::detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_BATCH_NORMALIZATION_GRAD, &desc, inputs);
-
+
BatchNormalizationGradOutputs outputValues;
outputValues.gradient = builder->CreateNodeOutput(node, 0, *desc.OutputGradientTensor);
outputValues.scaleGradient = builder->CreateNodeOutput(node, 1, *desc.OutputScaleGradientTensor);
@@ -2932,7 +2932,7 @@ namespace dml
{
Expression output;
Expression mean;
- Expression variance;
+ Expression variance;
};
inline BatchNormalizationTrainingOutputs BatchNormalizationTraining(
@@ -3005,14 +3005,14 @@ namespace dml
desc.VarianceTensor = varianceTensor.AsPtr();
desc.ScaleTensor = scaleTensor.AsPtr();
desc.Epsilon = epsilon;
-
+
desc.OutputGradientTensor = outputGradientTensor.AsPtr();
desc.OutputScaleGradientTensor = outputScaleGradientTensor.AsPtr();
desc.OutputBiasGradientTensor = outputBiasGradientTensor.AsPtr();
-
+
dml::detail::NodeOutput* const inputs[] = { input.Impl(), inputGradient.Impl(), mean.Impl(), variance.Impl(), scale.Impl() };
dml::detail::NodeID node = builder->CreateOperatorNode(DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD, &desc, inputs);
-
+
BatchNormalizationGradOutputs outputValues;
outputValues.gradient = builder->CreateNodeOutput(node, 0, *desc.OutputGradientTensor);
outputValues.scaleGradient = builder->CreateNodeOutput(node, 1, *desc.OutputScaleGradientTensor);
@@ -3099,17 +3099,17 @@ namespace dml
return output;
}
- //
+ //
// TODO: LpNormalization
- //
+ //
- //
+ //
// TODO: RNN
- //
+ //
- //
+ //
// TODO: LSTM
- //
+ //
enum class GRUOutputOptions
{
@@ -3121,7 +3121,7 @@ namespace dml
struct GRUOutputs
{
Expression sequence;
- Expression single;
+ Expression single;
};
inline GRUOutputs GRU(
@@ -3230,7 +3230,7 @@ namespace dml
return { outputSequenceExpr, outputSingleExpr };
}
- //
+ //
// TODO: DiagonalMatrix
//
@@ -3442,33 +3442,33 @@ namespace dml
return output;
}
- //
+ //
// TODO: MatrixMultiplyInteger
- //
+ //
- //
+ //
// TODO: QuantizedLinearMatrixMultiply
- //
+ //
- //
+ //
// TODO: ConvolutionInteger
- //
+ //
- //
+ //
// TODO: QuantizedLinearConvolution
- //
+ //
- //
+ //
// TODO: ReluGrad
- //
+ //
- //
+ //
// TODO: AveragePoolingGrad
- //
+ //
- //
+ //
// TODO: MaxPoolingGrad
- //
+ //
struct RandomGeneratorOutputs
{
@@ -3496,7 +3496,7 @@ namespace dml
// Input and output state have the same TensorDesc.
desc.OutputStateTensor = inputStateTensor.AsPtr();
}
-
+
RandomGeneratorOutputs out;
detail::NodeOutput* const inputs[] = { inputState.Impl() };
@@ -3537,7 +3537,7 @@ namespace dml
desc.InputTensor = inputTensor.AsPtr();
desc.OutputCountTensor = outputCountTensor.AsPtr();
desc.OutputCoordinatesTensor = outputCoordinatesTensor.AsPtr();
-
+
NonZeroCoordinatesOutputs output;
detail::NodeOutput* const inputs[] = { input.Impl() };
@@ -3640,17 +3640,17 @@ namespace dml
return output;
}
- //
+ //
// TODO: AdamOptimizer
- //
+ //
- //
+ //
// TODO: Argmin
- //
+ //
- //
+ //
// TODO: Argmax
- //
+ //
#if DML_TARGET_VERSION >= 0x4000
@@ -3694,7 +3694,7 @@ namespace dml
desc.ROITensor = roiTensor.AsPtr();
desc.BatchIndicesTensor = batchIndicesTensor.AsPtr();
desc.OutputTensor = outputTensor.AsPtr();
- desc.ReductionFunction = reductionFunction;
+ desc.ReductionFunction = reductionFunction;
desc.InterpolationMode = interpolationMode;
desc.SpatialScaleX = spatialScaleX;
desc.SpatialScaleY = spatialScaleY;
@@ -3763,7 +3763,7 @@ namespace dml
outputGradientTensor = TensorDesc(inputGradientTensor.dataType, outputGradientSizes, builder->GetTensorPolicy());
}
-
+
TensorDesc outputROIGradientTensor = computeOutputROIGradient ? TensorDesc(roiTensor.dataType, roiTensor.sizes, builder->GetTensorPolicy()) : TensorDesc();
assert(!computeOutputROIGradient || outputROIGradientTensor.sizes == roiTensor.sizes);
@@ -3774,7 +3774,7 @@ namespace dml
desc.BatchIndicesTensor = batchIndicesTensor.AsPtr();
desc.OutputGradientTensor = computeOutputGradient ? outputGradientTensor.AsPtr() : nullptr;
desc.OutputROIGradientTensor = computeOutputROIGradient ? outputROIGradientTensor.AsPtr() : nullptr;
- desc.ReductionFunction = reductionFunction;
+ desc.ReductionFunction = reductionFunction;
desc.InterpolationMode = interpolationMode;
desc.SpatialScaleX = spatialScaleX;
desc.SpatialScaleY = spatialScaleY;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
index aaf02ca146..833871de0b 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
@@ -1418,6 +1418,29 @@ inline std::vector GetFields(const DML_DIAGONAL_MATRIX1_OPERATOR_
OperatorField(&DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.DiagonalFillEnd))),
};
}
+inline std::vector<OperatorField> GetFields(const DML_MULTIHEAD_ATTENTION_OPERATOR_DESC& desc)
+{
+ return {
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.QueryTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.KeyTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ValueTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.StackedQueryKeyTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.StackedKeyValueTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.StackedQueryKeyValueTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.MaskTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.RelativePositionBiasTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.PastKeyTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.PastValueTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputPresentKeyTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputPresentValueTensor))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[14], ToOperatorFieldType(static_cast<FLOAT>(desc.Scale))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[15], ToOperatorFieldType(static_cast<FLOAT>(desc.MaskFilterValue))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[16], ToOperatorFieldType(static_cast<UINT>(desc.HeadCount))),
+        OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[17], ToOperatorFieldType(static_cast<UINT>(desc.MaskType))),
+ };
+}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc)
{
return {
@@ -1753,6 +1776,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_RESAMPLE2: return DML_RESAMPLE2_OPERATOR_SCHEMA;
case DML_OPERATOR_RESAMPLE_GRAD1: return DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA;
case DML_OPERATOR_DIAGONAL_MATRIX1: return DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA;
+ case DML_OPERATOR_MULTIHEAD_ATTENTION: return DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
@@ -2346,6 +2370,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA,
GetFields(*static_cast(opDesc.Desc)));
+ case DML_OPERATOR_MULTIHEAD_ATTENTION:
+ return AbstractOperatorDesc(
+ &DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA,
+ GetFields(*static_cast(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_ELU:
return AbstractOperatorDesc(
&DML_ACTIVATION_ELU_OPERATOR_SCHEMA,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h
new file mode 100644
index 0000000000..c63863853f
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h
@@ -0,0 +1,988 @@
+#pragma once
+
+#include "../../../OperatorAuthorHelper/OperatorHelper.h"
+#include "../MLOperatorAuthorImpl.h"
+
+#include "../External/D3DX12/d3dx12.h"
+#include <d3d12.h>
+
+// NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback
+// should be removed from IsCustomOpShader(...) in
+// onnxruntime\core\providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp
+
+// The shader headers are produced using "GeneratedShaders/GenerateShaders.bat"
+namespace GridSample_uint16_float
+{
+ #include "GeneratedShaders/grid_sample_uint16_float.h"
+}
+
+namespace GridSample_uint_float
+{
+ #include "GeneratedShaders/grid_sample_uint_float.h"
+}
+
+namespace GridSample_uint64_float
+{
+ #include "GeneratedShaders/grid_sample_uint64_float.h"
+}
+
+namespace GridSample_int16_float
+{
+ #include "GeneratedShaders/grid_sample_int16_float.h"
+}
+
+namespace GridSample_int_float
+{
+ #include "GeneratedShaders/grid_sample_int_float.h"
+}
+
+namespace GridSample_int64_float
+{
+ #include "GeneratedShaders/grid_sample_int64_float.h"
+}
+
+namespace GridSample_fp16_float
+{
+ #include "GeneratedShaders/grid_sample_fp16_float.h"
+}
+
+namespace GridSample_float_float
+{
+ #include "GeneratedShaders/grid_sample_float_float.h"
+}
+
+namespace GridSample_double_float
+{
+ #include "GeneratedShaders/grid_sample_double_float.h"
+}
+
+namespace GridSample_bool_float
+{
+ #include "GeneratedShaders/grid_sample_bool_float.h"
+}
+
+namespace GridSample_uint16_fp16
+{
+ #include "GeneratedShaders/grid_sample_uint16_fp16.h"
+}
+
+namespace GridSample_uint_fp16
+{
+ #include "GeneratedShaders/grid_sample_uint_fp16.h"
+}
+
+namespace GridSample_uint64_fp16
+{
+ #include "GeneratedShaders/grid_sample_uint64_fp16.h"
+}
+
+namespace GridSample_int16_fp16
+{
+ #include "GeneratedShaders/grid_sample_int16_fp16.h"
+}
+
+namespace GridSample_int_fp16
+{
+ #include "GeneratedShaders/grid_sample_int_fp16.h"
+}
+
+namespace GridSample_int64_fp16
+{
+ #include "GeneratedShaders/grid_sample_int64_fp16.h"
+}
+
+namespace GridSample_fp16_fp16
+{
+ #include "GeneratedShaders/grid_sample_fp16_fp16.h"
+}
+
+namespace GridSample_float_fp16
+{
+ #include "GeneratedShaders/grid_sample_float_fp16.h"
+}
+
+namespace GridSample_double_fp16
+{
+ #include "GeneratedShaders/grid_sample_double_fp16.h"
+}
+
+namespace GridSample_bool_fp16
+{
+ #include "GeneratedShaders/grid_sample_bool_fp16.h"
+}
+
+namespace GridSample_uint16_double
+{
+ #include "GeneratedShaders/grid_sample_uint16_double.h"
+}
+
+namespace GridSample_uint_double
+{
+ #include "GeneratedShaders/grid_sample_uint_double.h"
+}
+
+namespace GridSample_uint64_double
+{
+ #include "GeneratedShaders/grid_sample_uint64_double.h"
+}
+
+namespace GridSample_int16_double
+{
+ #include "GeneratedShaders/grid_sample_int16_double.h"
+}
+
+namespace GridSample_int_double
+{
+ #include "GeneratedShaders/grid_sample_int_double.h"
+}
+
+namespace GridSample_int64_double
+{
+ #include "GeneratedShaders/grid_sample_int64_double.h"
+}
+
+namespace GridSample_fp16_double
+{
+ #include "GeneratedShaders/grid_sample_fp16_double.h"
+}
+
+namespace GridSample_float_double
+{
+ #include "GeneratedShaders/grid_sample_float_double.h"
+}
+
+namespace GridSample_double_double
+{
+ #include "GeneratedShaders/grid_sample_double_double.h"
+}
+
+namespace GridSample_bool_double
+{
+ #include "GeneratedShaders/grid_sample_bool_double.h"
+}
+
+
+#include
+#include
+
+#include
+
+using namespace Microsoft::WRL;
+
+enum DmlGridSampleKernelInputIndex : uint32_t
+{
+ Input,
+ Grid,
+};
+
+enum DmlGridSampleMode : uint32_t
+{
+ Bilinear,
+ Nearest,
+ Bicubic,
+};
+
+enum DmlGridSamplePaddingMode : uint32_t
+{
+ Zeros,
+ Border,
+ Reflection
+};
+
+// Helper to derive dimensions and attributes from either the GridSample shape inferrer or the GridSample kernel constructor.
+struct DmlGridSampleParameters
+{
+ uint32_t batchSize = 0;
+ uint32_t channelSize = 0;
+ uint32_t height = 0;
+ uint32_t width = 0;
+ int64_t alignCorners = 0;
+ DmlGridSampleMode mode = DmlGridSampleMode::Bilinear;
+ DmlGridSamplePaddingMode paddingMode = DmlGridSamplePaddingMode::Zeros;
+
+ DML_TENSOR_DATA_TYPE dataType = DML_TENSOR_DATA_TYPE_UNKNOWN;
+
+ DmlGridSampleParameters(){}
+
+ DmlGridSampleParameters(
+ const OperatorHelper::IKernelInformationAdapter& kernelInfo,
+ const OperatorHelper::IShapeInformationAdapter& shapeInfo)
+ {
+ auto& attributes = kernelInfo.GetAttributes();
+
+        alignCorners = attributes.GetOptionalAttribute<int64_t>(AttrName::AlignCorners, 0);
+
+        std::string str_attrib = attributes.GetOptionalAttribute<std::string>(AttrName::Mode, "bilinear");
+ ML_CHECK_VALID_ARGUMENT(str_attrib == "bilinear" || str_attrib == "nearest" || str_attrib == "bicubic");
+ if (str_attrib == "bilinear")
+ {
+ mode = DmlGridSampleMode::Bilinear;
+ }
+ else if (str_attrib == "nearest")
+ {
+ mode = DmlGridSampleMode::Nearest;
+ }
+ else if (str_attrib == "bicubic")
+ {
+ mode = DmlGridSampleMode::Bicubic;
+ }
+
+        str_attrib = attributes.GetOptionalAttribute<std::string>(AttrName::PaddingMode, "zeros");
+ ML_CHECK_VALID_ARGUMENT(str_attrib == "zeros" || str_attrib == "border" || str_attrib == "reflection");
+ if (str_attrib == "zeros")
+ {
+ paddingMode = DmlGridSamplePaddingMode::Zeros;
+ }
+ else if (str_attrib == "border")
+ {
+ paddingMode = DmlGridSamplePaddingMode::Border;
+ }
+ else if (str_attrib == "reflection")
+ {
+ paddingMode = DmlGridSamplePaddingMode::Reflection;
+ }
+
+ // input 0: signal (required; tensor)
+ {
+ // Input shape is expected to be [batch_size, channels, height, width]
+ // 4-D tensor of shape (N, C, H_out, W_out) of sampled values.
+            // For integer input types, intermediate values are computed as floating point and cast to integer at the end.
+ uint32_t rank = shapeInfo.GetInputTensorDimensionCount(DmlGridSampleKernelInputIndex::Input);
+ ML_CHECK_VALID_ARGUMENT(rank == 4, "Input shape must be 4D.");
+
+ auto dims = shapeInfo.GetInputTensorShape(DmlGridSampleKernelInputIndex::Input);
+ assert(dims.size() == rank);
+ this->batchSize = dims[0];
+ this->channelSize = dims[1];
+
+ MLOperatorEdgeDescription edgeDesc = kernelInfo.GetInputEdgeDescription(DmlGridSampleKernelInputIndex::Input);
+
+ assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor);
+ this->dataType = Dml::GetDmlDataTypeFromMlDataType(edgeDesc.tensorDataType);
+ }
+
+ // input 1: grid (required; tensor)
+ {
+ // Grid shape is expected to be [batch_size, height_out, width_out, 2]
+ uint32_t rank = shapeInfo.GetInputTensorDimensionCount(DmlGridSampleKernelInputIndex::Grid);
+ ML_CHECK_VALID_ARGUMENT(rank == 4, "Input shape must be 4D.");
+
+ auto dims = shapeInfo.GetInputTensorShape(DmlGridSampleKernelInputIndex::Grid);
+ assert(dims.size() == rank);
+ this->height = dims[1];
+ this->width = dims[2];
+ }
+ }
+
+};
+
+namespace GridSampleHelpers
+{
+ // Divides and rounds
+ inline uint32_t CeilDivide(uint32_t dividend, uint32_t divisor)
+ {
+        uint64_t temp = static_cast<uint64_t>(dividend) + divisor - 1;
+        return static_cast<uint32_t>(temp / divisor);
+ }
+
+ // Gets the next number of elements to dispatch to the GPU within a loop handling a large
+ // total number of tensor elements and threads.
+ void GetNextDispatchSize(
+ uint32_t elementCount,
+ uint32_t elementsPerThread,
+ uint32_t numThreads,
+ _Out_ uint32_t& dispatch,
+ _Out_ uint32_t& pendingElementCount
+ )
+ {
+ // Max threads per workgroup is 2^10 (1024). Max dispatch per dimension is 2^16. Taken together, we can dispatch a maximum of
+ // 2^26 (268,435,456) threads along a single dimension. This should suffice for a majority of the workload. Therefore, even
+ // though it is possible to dispatch up to (2^16)^3 workgroups simultaneously, we stick to the simpler 1D dispatch alternative.
+ assert(numThreads <= D3D12_CS_THREAD_GROUP_MAX_THREADS_PER_GROUP);
+
+ const uint32_t maxThreadsPerDispatch = numThreads * D3D12_CS_DISPATCH_MAX_THREAD_GROUPS_PER_DIMENSION;
+
+ const uint32_t requiredThreadCount = CeilDivide(elementCount, elementsPerThread);
+
+ // Compute max dispatchable elements
+ const uint32_t availableThreadCount = std::min(requiredThreadCount, maxThreadsPerDispatch);
+
+ // Compute required thread group count
+ uint32_t workGroupCount1D = CeilDivide(availableThreadCount, numThreads);
+
+ // Compute min dispatch size
+ dispatch = workGroupCount1D;
+
+ // With the dispatch size computed, compute the dispatched element count
+ const uint32_t dispatchedElementCount = workGroupCount1D * numThreads * elementsPerThread;
+
+ // Update the pending element count
+ pendingElementCount = (dispatchedElementCount < elementCount) ? elementCount - dispatchedElementCount : 0;
+ }
+}
+
+class DmlGridSampleOperator : public WRL::Base<IMLOperatorKernel>
+{
+private:
+    ComPtr<ID3D12Device> m_device;
+    ComPtr<ID3D12RootSignature> m_gridSampleRootSignature;
+    ComPtr<ID3D12PipelineState> m_gridSamplePipelineState;
+ DmlGridSampleParameters m_params = {};
+
+
+ // Allocate temporary buffers if needed
+ struct ResourceDesc
+ {
+        ComPtr<ID3D12Resource> Resource;
+ std::array